#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#

'''
Hartree-Fock
'''

import sys
import tempfile

from functools import reduce
import numpy
import scipy.linalg
import h5py
from pyscf import gto
from pyscf import lib
from pyscf.lib import logger
from pyscf.scf import diis
from pyscf.scf import _vhf
from pyscf.scf import chkfile
from pyscf.scf import dispersion
from pyscf.data import nist
from pyscf import __config__


# Module-level defaults.  Each value is looked up on pyscf.__config__ under
# the quoted attribute name and falls back to the literal given here when
# that attribute is absent.
WITH_META_LOWDIN = getattr(__config__, 'scf_analyze_with_meta_lowdin', True)
PRE_ORTH_METHOD = getattr(__config__, 'scf_analyze_pre_orth_method', 'ANO')
MO_BASE = getattr(__config__, 'MO_BASE', 1)
# When False, kernel() rescales the orbital-gradient norm before comparing it
# against conv_tol_grad.
TIGHT_GRAD_CONV_TOL = getattr(__config__, 'scf_hf_kernel_tight_grad_conv_tol', True)
MUTE_CHKFILE = getattr(__config__, 'scf_hf_SCF_mute_chkfile', False)

def kernel(mf, conv_tol=1e-10, conv_tol_grad=None,
           dump_chk=True, dm0=None, callback=None, conv_check=True, **kwargs):
    '''kernel: the SCF driver.

    Args:
        mf : an instance of SCF class
            mf object holds all parameters to control SCF.  One can modify its
            member functions to change the behavior of SCF.  The member
            functions which are called in kernel are

            | mf.get_init_guess
            | mf.get_hcore
            | mf.get_ovlp
            | mf.get_veff
            | mf.get_fock
            | mf.get_grad
            | mf.eig
            | mf.get_occ
            | mf.make_rdm1
            | mf.energy_tot
            | mf.dump_chk
            | mf.pre_kernel
            | mf.post_kernel
            | mf.check_convergence (only if it is callable)

    Kwargs:
        conv_tol : float
            Convergence threshold on the change of the total energy between
            two consecutive cycles.
        conv_tol_grad : float
            Convergence threshold on the norm of the orbital gradients.  If
            not given, it is set to sqrt(conv_tol).
        dump_chk : bool
            Whether to save SCF intermediate results in the checkpoint file
            (only effective when ``mf.chkfile`` is set).
        dm0 : ndarray
            Initial guess density matrix.  If not given (the default), the kernel
            takes the density matrix generated by ``mf.get_init_guess``.
        callback : function(envs_dict) => None
            callback function takes one dict as the argument which is
            generated by the builtin function :func:`locals`, so that the
            callback function can access all local variables in the current
            environment.
        conv_check : bool
            If True and the SCF loop converged, one extra diagonalization and
            energy evaluation is carried out after the loop, and convergence
            is tested again with thresholds loosened to ``conv_tol*10`` and
            ``conv_tol_grad*3``.
        sap_basis : str
            SAP basis name.  Like any other extra keyword argument, it is
            forwarded to ``mf.get_init_guess`` when ``dm0`` is not given.

    Returns:
        A tuple :   scf_conv, e_tot, mo_energy, mo_coeff, mo_occ

        scf_conv : bool
            True means SCF converged
        e_tot : float
            Hartree-Fock energy of last iteration
        mo_energy : 1D float array
            Orbital energies.  Depending on the eig function provided by mf
            object, the orbital energies may NOT be sorted.
        mo_coeff : 2D array
            Orbital coefficients.
        mo_occ : 1D array
            Orbital occupancies.  The occupancies may NOT be sorted from large
            to small.

    Raises:
        RuntimeError : if the obsolete keyword argument ``init_dm`` is given.

    Examples:

    >>> from pyscf import gto, scf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1', basis='cc-pvdz')
    >>> conv, e, mo_e, mo, mo_occ = scf.hf.kernel(scf.hf.SCF(mol), dm0=numpy.eye(mol.nao_nr()))
    >>> print('conv = %s, E(HF) = %.12f' % (conv, e))
    conv = True, E(HF) = -1.081170784378
    '''
    if 'init_dm' in kwargs:
        raise RuntimeError('''
You see this error message because of the API updates in pyscf v0.11.
Keyword argument "init_dm" is replaced by "dm0"''')
    cput0 = (logger.process_clock(), logger.perf_counter())
    if conv_tol_grad is None:
        conv_tol_grad = numpy.sqrt(conv_tol)
        logger.info(mf, 'Set gradient conv threshold to %g', conv_tol_grad)

    mol = mf.mol
    s1e = mf.get_ovlp(mol)

    if dm0 is None:
        # Remaining keyword arguments are handed to the initial-guess builder.
        dm = mf.get_init_guess(mol, mf.init_guess, s1e=s1e, **kwargs)
    else:
        dm = dm0

    h1e = mf.get_hcore(mol)
    vhf = mf.get_veff(mol, dm)
    e_tot = mf.energy_tot(dm, h1e, vhf)
    logger.info(mf, 'init E= %.15g', e_tot)

    scf_conv = False
    mo_energy = mo_coeff = mo_occ = None

    # Skip SCF iterations. Compute only the total energy of the initial density
    if mf.max_cycle <= 0:
        fock = mf.get_fock(h1e, s1e, vhf, dm)  # = h1e + vhf, no DIIS
        mo_energy, mo_coeff = mf.eig(fock, s1e)
        mo_occ = mf.get_occ(mo_energy, mo_coeff)
        return scf_conv, e_tot, mo_energy, mo_coeff, mo_occ

    # Select the DIIS object: a ready-made instance on mf.diis is used as is;
    # any other truthy mf.diis triggers construction from the mf.DIIS class;
    # a falsy mf.diis disables DIIS.
    if isinstance(mf.diis, lib.diis.DIIS):
        mf_diis = mf.diis
    elif mf.diis:
        assert issubclass(mf.DIIS, lib.diis.DIIS)
        mf_diis = mf.DIIS(mf, mf.diis_file)
        mf_diis.space = mf.diis_space
        mf_diis.rollback = mf.diis_space_rollback
        mf_diis.damp = mf.diis_damp

        # We get the used orthonormalized AO basis from any old eigendecomposition.
        # Since the ingredients for the Fock matrix has already been built, we can
        # just go ahead and use it to determine the orthonormal basis vectors.
        fock = mf.get_fock(h1e, s1e, vhf, dm)
        _, mf_diis.Corth = mf.eig(fock, s1e)
    else:
        mf_diis = None

    if dump_chk and mf.chkfile:
        # Explicit overwrite the mol object in chkfile
        # Note in pbc.scf, mf.mol == mf.cell, cell is saved under key "mol"
        chkfile.save_mol(mol, mf.chkfile)

    # A preprocessing hook before the SCF iteration
    mf.pre_kernel(locals())

    fock_last = None
    cput1 = logger.timer(mf, 'initialize scf', *cput0)
    mf.cycles = 0
    for cycle in range(mf.max_cycle):
        dm_last = dm
        last_hf_e = e_tot

        # Fock matrix for this cycle; the cycle index and DIIS object are
        # passed so that get_fock can apply its convergence accelerators.
        fock = mf.get_fock(h1e, s1e, vhf, dm, cycle, mf_diis, fock_last=fock_last)
        mo_energy, mo_coeff = mf.eig(fock, s1e)
        mo_occ = mf.get_occ(mo_energy, mo_coeff)
        dm = mf.make_rdm1(mo_coeff, mo_occ)
        # The previous density and potential are passed along with the new
        # density.
        vhf = mf.get_veff(mol, dm, dm_last, vhf)
        e_tot = mf.energy_tot(dm, h1e, vhf)

        # Here Fock matrix is h1e + vhf, without DIIS.  Calling get_fock
        # instead of the statement "fock = h1e + vhf" because Fock matrix may
        # be modified in some methods.
        fock_last = fock
        fock = mf.get_fock(h1e, s1e, vhf, dm)  # = h1e + vhf, no DIIS
        norm_gorb = numpy.linalg.norm(mf.get_grad(mo_coeff, mo_occ, fock))
        if not TIGHT_GRAD_CONV_TOL:
            # NOTE(review): norm_gorb is the value returned by
            # numpy.linalg.norm above, so .size is presumably 1 and this
            # division is a no-op -- confirm whether the size of the gradient
            # vector was intended.
            norm_gorb = norm_gorb / numpy.sqrt(norm_gorb.size)
        norm_ddm = numpy.linalg.norm(dm-dm_last)
        logger.info(mf, 'cycle= %d E= %.15g  delta_E= %4.3g  |g|= %4.3g  |ddm|= %4.3g',
                    cycle+1, e_tot, e_tot-last_hf_e, norm_gorb, norm_ddm)

        # A user-supplied convergence test takes precedence over the default
        # energy-change AND gradient-norm criterion.
        if callable(mf.check_convergence):
            scf_conv = mf.check_convergence(locals())
        elif abs(e_tot-last_hf_e) < conv_tol and norm_gorb < conv_tol_grad:
            scf_conv = True

        if dump_chk and mf.chkfile:
            mf.dump_chk(locals())

        if callable(callback):
            callback(locals())

        cput1 = logger.timer(mf, 'cycle= %d'%(cycle+1), *cput1)

        if scf_conv:
            break

    mf.cycles = cycle + 1
    if scf_conv and conv_check:
        # An extra diagonalization, to remove level shift
        #fock = mf.get_fock(h1e, s1e, vhf, dm)  # = h1e + vhf
        mo_energy, mo_coeff = mf.eig(fock, s1e)
        mo_occ = mf.get_occ(mo_energy, mo_coeff)
        dm, dm_last = mf.make_rdm1(mo_coeff, mo_occ), dm
        vhf = mf.get_veff(mol, dm, dm_last, vhf)
        e_tot, last_hf_e = mf.energy_tot(dm, h1e, vhf), e_tot

        fock = mf.get_fock(h1e, s1e, vhf, dm)
        norm_gorb = numpy.linalg.norm(mf.get_grad(mo_coeff, mo_occ, fock))
        if not TIGHT_GRAD_CONV_TOL:
            norm_gorb = norm_gorb / numpy.sqrt(norm_gorb.size)
        norm_ddm = numpy.linalg.norm(dm-dm_last)

        # The extra cycle is judged with loosened thresholds, and (in the
        # default test) satisfying EITHER criterion is sufficient.
        conv_tol = conv_tol * 10
        conv_tol_grad = conv_tol_grad * 3
        if callable(mf.check_convergence):
            scf_conv = mf.check_convergence(locals())
        elif abs(e_tot-last_hf_e) < conv_tol or norm_gorb < conv_tol_grad:
            scf_conv = True
        else:
            scf_conv = False
        logger.info(mf, 'Extra cycle  E= %.15g  delta_E= %4.3g  |g|= %4.3g  |ddm|= %4.3g',
                    e_tot, e_tot-last_hf_e, norm_gorb, norm_ddm)
        if dump_chk and mf.chkfile:
            mf.dump_chk(locals())

    logger.timer(mf, 'scf_cycle', *cput0)
    # A post-processing hook before return
    mf.post_kernel(locals())
    return scf_conv, e_tot, mo_energy, mo_coeff, mo_occ


def energy_elec(mf, dm=None, h1e=None, vhf=None):
    r'''Electronic Hartree-Fock energy for a given density matrix, core
    Hamiltonian and HF potential

    .. math::

        E = \sum_{ij}h_{ij} \gamma_{ji}
          + \frac{1}{2}\sum_{ijkl} \gamma_{ji}\gamma_{lk} \langle ik||jl\rangle

    Side effect: the entries ``'e1'`` and ``'e2'`` of ``mf.scf_summary`` are
    overwritten with the one-electron and two-electron contributions.

    Args:
        mf : an instance of SCF class

    Kwargs:
        dm : 2D ndarray
            one-particle density matrix.  Defaults to ``mf.make_rdm1()``.
        h1e : 2D ndarray
            Core hamiltonian.  Defaults to ``mf.get_hcore()``.
        vhf : 2D ndarray
            HF potential.  Defaults to ``mf.get_veff(mf.mol, dm)``.

    Returns:
        A 2-tuple: the electronic energy (one-electron plus two-electron
        part) and the two-electron part alone.

    Examples:

    >>> from pyscf import gto, scf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> mf = scf.RHF(mol)
    >>> mf.scf()
    >>> dm = mf.make_rdm1()
    >>> scf.hf.energy_elec(mf, dm)
    (-1.5176090667746334, 0.60917167853723675)
    >>> mf.energy_elec(dm)
    (-1.5176090667746334, 0.60917167853723675)
    '''
    # Fill in whatever the caller did not provide.
    if dm is None:
        dm = mf.make_rdm1()
    if h1e is None:
        h1e = mf.get_hcore()
    if vhf is None:
        vhf = mf.get_veff(mf.mol, dm)

    # Tr(h dm) and 1/2 Tr(vhf dm); only the real parts are kept.
    one_body = numpy.einsum('ij,ji->', h1e, dm).real
    two_body = numpy.einsum('ij,ji->', vhf, dm).real * .5

    mf.scf_summary['e1'] = one_body
    mf.scf_summary['e2'] = two_body
    logger.debug(mf, 'E1 = %s  E_coul = %s', one_body, two_body)
    return one_body+two_body, two_body


def energy_tot(mf, dm=None, h1e=None, vhf=None):
    r'''Total Hartree-Fock energy: electronic energy plus nuclear repulsion,
    plus the dispersion correction when ``mf.do_disp()`` is true.
    See :func:`scf.hf.energy_elec` for the electronic part.

    Side effect: ``mf.scf_summary['nuc']`` is overwritten, and
    ``mf.scf_summary['dispersion']`` is filled in from
    ``mf.get_dispersion()`` the first time it is needed.

    Returns:
        The total energy.
    '''
    e_nuc = mf.energy_nuc()
    mf.scf_summary['nuc'] = e_nuc.real

    e_tot = mf.energy_elec(dm, h1e, vhf)[0] + e_nuc
    if not mf.do_disp():
        return e_tot

    # Evaluate the dispersion term once; later calls reuse the stored value.
    if 'dispersion' not in mf.scf_summary:
        mf.scf_summary['dispersion'] = mf.get_dispersion()
    e_tot += mf.scf_summary['dispersion']
    return e_tot


def get_hcore(mol):
    '''Core Hamiltonian: kinetic energy plus the nuclear attraction (or the
    GTH pseudopotential when ``mol._pseudo`` is set), plus the scalar ECP
    term when ``mol._ecpbas`` is not empty.

    Returns:
        Core Hamiltonian matrix, 2D ndarray

    Examples:

    >>> from pyscf import gto, scf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> scf.hf.get_hcore(mol)
    array([[-0.93767904, -0.59316327],
           [-0.59316327, -0.93767904]])
    '''
    hcore = mol.intor_symmetric('int1e_kin')

    if not mol._pseudo:
        hcore += mol.intor_symmetric('int1e_nuc')
    else:
        # mol._pseudo for GTH PP is normally only available in Cell, but it
        # may be present when mol was converted from a cell object.
        from pyscf.gto import pp_int
        hcore += pp_int.get_gth_pp(mol)

    has_ecp = len(mol._ecpbas) > 0
    if has_ecp:
        hcore += mol.intor_symmetric('ECPscalar')
    return hcore


def get_ovlp(mol):
    '''Overlap matrix of ``mol``, evaluated with the ``int1e_ovlp`` integral

    Returns:
        The result of ``mol.intor_symmetric('int1e_ovlp')``
    '''
    ovlp = mol.intor_symmetric('int1e_ovlp')
    return ovlp


def init_guess_by_minao(mol):
    '''Generate initial guess density matrix based on ANO basis, then project
    the density matrix to the basis set defined by ``mol``

    If any atom has a nuclear charge larger than 96, the function falls back
    to :func:`init_guess_by_atom`.  Ghost atoms contribute no orbitals to the
    guess.

    Returns:
        Density matrix, 2D ndarray.  The array is tagged with the attributes
        ``mo_coeff`` (the projected minimal-basis orbitals) and ``mo_occ``
        (their occupancies).

    Examples:

    >>> from pyscf import gto, scf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> scf.hf.init_guess_by_minao(mol)
    array([[ 0.94758917,  0.09227308],
           [ 0.09227308,  0.94758917]])
    '''
    from pyscf.scf import atom_hf
    from pyscf.scf import addons

    def minao_basis(symb, nelec_ecp):
        '''Occupancies and minimal basis for one atom symbol.

        Args:
            symb : atom symbol as returned by ``mol.atom_symbol``
            nelec_ecp : number of core electrons of that atom, as returned by
                ``mol.atom_nelec_core``

        Returns:
            (occ, basis): orbital occupancies and the matching basis
            specification.  Both are empty lists for ghost atoms.
        '''
        occ = []
        basis_ano = []
        if gto.is_ghost_atom(symb):
            return occ, basis_ano

        stdsymb = gto.mole._std_symbol(symb)
        basis_add = gto.basis.load('ano', stdsymb)
        # coreshl defines the core shells to be removed in the initial guess
        coreshl = gto.ecp.core_configuration(nelec_ecp, atom_symbol=stdsymb)
        # coreshl = (0,0,0,0)  # it keeps all core electrons in the initial guess
        # Only angular momenta l = 0..3 are considered.
        for l in range(4):
            ndocc, frac = atom_hf.frac_occ(stdsymb, l)
            if ndocc >= coreshl[l]:
                degen = l * 2 + 1
                # Doubly occupied shells left after removing the core shells,
                # followed by one shell holding the fractional occupancy.
                occ_l = [2, ]*(ndocc-coreshl[l]) + [frac, ]
                occ.append(numpy.repeat(occ_l, degen))
                # Keep the exponent column (b[:1]) and the contraction
                # columns of the retained shells only.
                basis_ano.append([l] + [b[:1] + b[1+coreshl[l]:ndocc+2]
                                        for b in basis_add[l][1:]])
            else:
                logger.debug(mol, '*** ECP incorporates partially occupied '
                             'shell of l = %d for atom %s ***', l, symb)
        occ = numpy.hstack(occ)

        if nelec_ecp > 0:
            # Look up the basis the user assigned to this atom, first by the
            # full symbol and then by the standardized symbol.
            if symb in mol._basis:
                input_basis = mol._basis[symb]
            elif stdsymb in mol._basis:
                input_basis = mol._basis[stdsymb]
            else:
                raise KeyError(symb)

            # Group the input basis shells by angular momentum (l < 4 only).
            basis4ecp = [[] for i in range(4)]
            for bas in input_basis:
                l = bas[0]
                if l < 4:
                    basis4ecp[l].append(bas)

            occ4ecp = []
            for l in range(4):
                # Number of contracted functions of angular momentum l
                nbas_l = sum((len(bas[1]) - 1) for bas in basis4ecp[l])
                ndocc, frac = atom_hf.frac_occ(stdsymb, l)
                ndocc -= coreshl[l]
                assert ndocc <= nbas_l

                if nbas_l > 0:
                    occ_l = numpy.zeros(nbas_l)
                    occ_l[:ndocc] = 2
                    if frac > 0:
                        occ_l[ndocc] = frac
                    occ4ecp.append(numpy.repeat(occ_l, l * 2 + 1))

            occ4ecp = numpy.hstack(occ4ecp)
            basis4ecp = lib.flatten(basis4ecp)

            # Compared to ANO valence basis, to check whether the ECP basis set has
            # reasonable AO-character contraction.  The ANO valence AO should have
            # significant overlap to ECP basis if the ECP basis has AO-character.
            atm1 = gto.Mole()
            atm2 = gto.Mole()
            atom = [[symb, (0.,0.,0.)]]
            atm1._atm, atm1._bas, atm1._env = atm1.make_env(atom, {symb:basis4ecp}, [])
            atm2._atm, atm2._bas, atm2._env = atm2.make_env(atom, {symb:basis_ano}, [])
            atm1._built = True
            atm2._built = True
            s12 = gto.intor_cross('int1e_ovlp', atm1, atm2)
            # Use the input (ECP) basis only when the overlap block between
            # the occupied functions of both sets has a determinant above 0.1.
            if abs(numpy.linalg.det(s12[occ4ecp>0][:,occ>0])) > .1:
                occ, basis_ano = occ4ecp, basis4ecp
            else:
                logger.debug(mol, 'Density of valence part of ANO basis '
                             'will be used as initial guess for %s', symb)
        return occ, basis_ano

    # Issue 548
    if any(gto.charge(mol.atom_symbol(ia)) > 96 for ia in range(mol.natm)):
        logger.info(mol, 'MINAO initial guess is not available for super-heavy '
                    'elements. "atom" initial guess is used.')
        return init_guess_by_atom(mol)

    # One entry per distinct atom symbol
    nelec_ecp_dic = {mol.atom_symbol(ia): mol.atom_nelec_core(ia)
                          for ia in range(mol.natm)}

    basis = {}
    occdic = {}
    for symb, nelec_ecp in nelec_ecp_dic.items():
        occ_add, basis_add = minao_basis(symb, nelec_ecp)
        occdic[symb] = occ_add
        basis[symb] = basis_add

    # Assemble the occupancies atom by atom, leaving out ghost atoms.
    occ = []
    new_atom = []
    for ia in range(mol.natm):
        symb = mol.atom_symbol(ia)
        if not gto.is_ghost_atom(symb):
            occ.append(occdic[symb])
            new_atom.append(mol._atom[ia])
    occ = numpy.hstack(occ)

    # Auxiliary molecule carrying the minimal basis on the non-ghost atoms
    pmol = gto.Mole()
    pmol._atm, pmol._bas, pmol._env = pmol.make_env(new_atom, basis, [])
    pmol._built = True

    #: dm = addons.project_dm_nr2nr(pmol, numpy.diag(occ), mol)
    mo = addons.project_mo_nr2nr(pmol, numpy.eye(pmol.nao), mol)
    dm = lib.dot(mo*occ, mo.conj().T)
    # normalize electron number
    #    s = mol.intor_symmetric('int1e_ovlp')
    #    dm *= mol.nelectron / (dm*s).sum()
    return lib.tag_array(dm, mo_coeff=mo, mo_occ=occ)


def init_guess_by_1e(mol):
    '''Initial guess density matrix obtained from the core hamiltonian.

    Thin wrapper which delegates to the ``init_guess_by_1e`` method of an
    ``RHF`` object created for ``mol``.

    Returns:
        Density matrix, 2D ndarray
    '''
    return RHF(mol).init_guess_by_1e(mol)


def init_guess_by_atom(mol):
    '''Initial guess density matrix built as the superposition (block
    diagonal) of atomic HF density matrices.  The atomic HF is occupancy
    averaged RHF.

    Atoms without an atomic HF result (looked up by atom symbol, then by pure
    symbol) contribute a zero block.

    Returns:
        Density matrix, 2D ndarray, tagged with the attributes ``mo_coeff``
        and ``mo_occ``
    '''
    from pyscf.scf import atom_hf
    atomic_results = atom_hf.get_atm_nrhf(mol)
    ao_ranges = mol.aoslice_by_atom()

    dm_blocks, coeff_blocks, occ_blocks = [], [], []
    for iatm in range(mol.natm):
        label = mol.atom_symbol(iatm)
        if label not in atomic_results:
            label = mol.atom_pure_symbol(iatm)

        if label not in atomic_results:
            # The basis of this atom is not specified in the input
            size = ao_ranges[iatm,3] - ao_ranges[iatm,2]
            coeff = numpy.zeros((size, size))
            occupancy = numpy.zeros(size)
        else:
            _, _, coeff, occupancy = atomic_results[label]

        dm_blocks.append(numpy.dot(coeff*occupancy, coeff.conj().T))
        coeff_blocks.append(coeff)
        occ_blocks.append(occupancy)

    dm = scipy.linalg.block_diag(*dm_blocks)
    mo_coeff = scipy.linalg.block_diag(*coeff_blocks)
    mo_occ = numpy.hstack(occ_blocks)

    if mol.cart:
        # Transform with the cart2sph coefficients when mol uses Cartesian
        # basis functions.
        cart2sph = mol.cart2sph_coeff(normalized='sp')
        dm = reduce(lib.dot, (cart2sph, dm, cart2sph.T))
        mo_coeff = lib.dot(cart2sph, mo_coeff)

    for label, result in atomic_results.items():
        logger.debug1(mol, 'Atom %s, E = %.12g', label, result[0])
    return lib.tag_array(dm, mo_coeff=mo_coeff, mo_occ=mo_occ)

def init_guess_by_huckel(mol):
    '''Initial guess density matrix from a Huckel calculation based on
    occupancy averaged atomic RHF calculations, doi:10.1021/acs.jctc.8b01089

    The Huckel orbitals are built with the original GWH rule
    (``updated_rule=False``) and occupied through :func:`get_occ`.

    Returns:
        Density matrix, 2D ndarray
    '''
    huckel_e, huckel_c = _init_guess_huckel_orbitals(mol, updated_rule=False)
    mf = SCF(mol)
    occupancy = get_occ(mf, huckel_e, huckel_c)
    return make_rdm1(huckel_c, occupancy)

def init_guess_by_mod_huckel(mol):
    '''Initial guess density matrix from a Huckel calculation based on
    occupancy averaged atomic RHF calculations, doi:10.1021/acs.jctc.8b01089

    Unlike init_guess_by_huckel, the Huckel orbitals are built with the
    updated GWH rule from doi:10.1021/ja00480a005 (``updated_rule=True``).

    Returns:
        Density matrix, 2D ndarray

    '''
    huckel_e, huckel_c = _init_guess_huckel_orbitals(mol, updated_rule=True)
    mf = SCF(mol)
    occupancy = get_occ(mf, huckel_e, huckel_c)
    return make_rdm1(huckel_c, occupancy)

def Kgwh(Ei, Ej, updated_rule=False):
    '''Computes the generalized Wolfsberg-Helmholtz parameter

    Args:
        Ei, Ej : the two orbital energies
        updated_rule : if True, the energy-dependent scheme of
            J. Am. Chem. Soc. 100, 3686 (1978); doi:10.1021/ja00480a005
            is applied; otherwise the constant 1.75 is returned.

    Returns:
        The GWH parameter for the pair (Ei, Ej)
    '''
    # GWH parameter value
    k = 1.75

    # Original rule: a constant, independent of the orbital energies
    if not updated_rule:
        return k

    # Updated scheme: correction in powers of the relative energy difference
    Delta = (Ei-Ej)/(Ei+Ej)
    return k + Delta**2 + Delta**4 * (1 - k)

def _init_guess_huckel_orbitals(mol, updated_rule = False):
    '''Huckel orbitals in the basis of the occupied orbitals of occupancy
    averaged atomic RHF calculations, doi:10.1021/acs.jctc.8b01089

    Arguments:
        mol, the molecule
        updated_rule, boolean triggering use of the updated GWH rule from doi:10.1021/ja00480a005

    Returns:
        An 1D array for Huckel orbital energies and an 2D array for orbital coefficients
    '''
    from pyscf.scf import atom_hf
    atomic_results = atom_hf.get_atm_nrhf(mol)

    # Orbital energies, coefficients and occupations of each atom
    at_e, at_c, at_occ = [], [], []
    for ia in range(mol.natm):
        label = mol.atom_symbol(ia)
        if label not in atomic_results:
            label = mol.atom_pure_symbol(ia)
        _, energies, coeffs, occs = atomic_results[label]
        at_c.append(coeffs)
        at_e.append(energies)
        at_occ.append(occs)

    # Total number of occupied atomic orbitals
    nocc = sum(1 for occs in at_occ for n in occs if n > 0.0)

    # Number of basis functions
    nbf = mol.nao_nr()
    # Energies and AO coefficients of the occupied atomic orbitals
    orb_E = numpy.zeros(nocc)
    orb_C = numpy.zeros((nbf, nocc))

    # AO index range of every atom
    aoslice = mol.aoslice_by_atom()

    # Per-atom transformation blocks, only needed for Cartesian basis sets
    atcart2sph = None
    if mol.cart:
        molcart2sph = mol.cart2sph_coeff(normalized='sp')
        atcart2sph = []
        for ia in range(mol.natm):
            # Rows of the molecular transformation belonging to this atom
            rows = molcart2sph[aoslice[ia, 2]:aoslice[ia, 3], :]
            # Keep the columns which have nonzero entries on the atom
            colnorm = numpy.asarray([numpy.linalg.norm(col) for col in rows.T])
            atcart2sph.append(rows[:, colnorm != 0.0])

    # Copy every occupied atomic orbital into the rows of its atom
    iocc = 0
    for ia in range(mol.natm):
        abeg, aend = aoslice[ia, 2], aoslice[ia, 3]
        for iorb, n in enumerate(at_occ[ia]):
            if n > 0.0:
                vec = at_c[ia][:, iorb]
                if mol.cart:
                    vec = numpy.dot(vec, atcart2sph[ia].T)
                orb_C[abeg:aend, iocc] = vec
                orb_E[iocc] = at_e[ia][iorb]
                iocc += 1

    # Overlap between the occupied atomic orbitals
    S = get_ovlp(mol)
    orb_S = orb_C.transpose().dot(S).dot(orb_C)

    # Huckel matrix: orbital energies on the diagonal, GWH approximation for
    # the off-diagonal elements
    orb_H = numpy.diag(orb_E)
    for io in range(1, nocc):
        for jo in range(io):
            gwh = Kgwh(orb_E[io], orb_E[jo], updated_rule=updated_rule)
            coupling = 0.5*gwh*orb_S[io,jo]*(orb_E[io]+orb_E[jo])
            orb_H[io,jo] = coupling
            orb_H[jo,io] = orb_H[io,jo]

    # Energies and coefficients in the minimal orbital basis ...
    mo_E, atmo_C = eig(orb_H, orb_S)
    # ... and the coefficients in the AO basis
    mo_C = orb_C.dot(atmo_C)

    return mo_E, mo_C


def init_guess_by_chkfile(mol, chkfile_name, project=None):
    '''Read the HF results from checkpoint file, then project it to the
    basis defined by ``mol``

    Args:
        mol : the molecule for which the initial guess is generated
        chkfile_name : name of the checkpoint file to load the SCF results from

    Kwargs:
        project : None or bool
            Whether to project chkfile's orbitals to the new basis.  Note when
            the geometry of the chkfile and the given molecule are very
            different, this projection can produce very poor initial guess.
            In PES scanning, it is recommended to switch off project.

            If project is set to None, the projection is only applied when the
            basis sets of the chkfile's molecule are different to the basis
            sets of the given molecule (regardless whether the geometry of
            the two molecules are different).  Note the basis sets are
            considered to be different if the two molecules are derived from
            the same molecule with different ordering of atoms.

    Returns:
        Density matrix, 2D ndarray.  When the checkpoint file holds separate
        alpha and beta orbitals, the sum of both densities is returned, tagged
        with the attributes ``mo_coeff`` and ``mo_occ``.

    Raises:
        NotImplementedError : if the stored orbitals are a single set of
            complex orbitals.
    '''
    from pyscf.scf import addons
    chk_mol, scf_rec = chkfile.load_scf(chkfile_name)
    if project is None:
        project = not gto.same_basis_set(chk_mol, mol)

    # Check whether the two molecules are similar
    im1 = scipy.linalg.eigvalsh(mol.inertia_moment())
    im2 = scipy.linalg.eigvalsh(chk_mol.inertia_moment())
    # im1+1e-7 to avoid 'divide by zero' error
    if abs((im1-im2)/(im1+1e-7)).max() > 0.01:
        logger.warn(mol, "Large deviations found between the input "
                    "molecule and the molecule from chkfile\n"
                    "Initial guess density matrix may have large error.")

    # The overlap matrix is needed for the renormalization of projected
    # orbitals and for the UHF branch below.  It is evaluated at most once.
    s = get_ovlp(mol) if project else None

    def fproj(mo):
        # Project the orbitals to the basis of mol and renormalize each
        # orbital with respect to the overlap matrix of mol.
        if project:
            mo = addons.project_mo_nr2nr(chk_mol, mo, mol)
            norm = numpy.einsum('pi,pi->i', mo.conj(), s.dot(mo))
            mo /= numpy.sqrt(norm)
        return mo

    mo = scf_rec['mo_coeff']
    mo_occ = scf_rec['mo_occ']
    if getattr(mo[0], 'ndim', None) == 1:  # RHF
        if numpy.iscomplexobj(mo):
            raise NotImplementedError('TODO: project DHF orbital to UHF orbital')
        mo_coeff = fproj(mo)
        dm = make_rdm1(mo_coeff, mo_occ)
    else:  #UHF
        if getattr(mo[0][0], 'ndim', None) == 2:  # KUHF
            logger.warn(mol, 'k-point UHF results are found.  Density matrix '
                        'at Gamma point is used for the molecular SCF initial guess')
            mo = mo[0]
        dma = make_rdm1(fproj(mo[0]), mo_occ[0])
        dmb = make_rdm1(fproj(mo[1]), mo_occ[1])
        dm = dma + dmb
        if s is None:
            # No projection was requested, so the overlap was not built yet.
            s = get_ovlp(mol)
        # Orbitals tagged on the total density are the eigenvectors of the
        # generalized problem (type=2) for dm and s, in reversed column order.
        _, mo_coeff = scipy.linalg.eigh(dm, s, type=2)
        dm = lib.tag_array(dm, mo_coeff=mo_coeff[:,::-1], mo_occ=mo_occ)
    return dm

def init_guess_by_sap(mol, sap_basis, **kwargs):
    '''Initial guess density matrix from a superposition of atomic
    potentials (SAP), doi:10.1021/acs.jctc.8b01089.
    This is the Gaussian fit implementation, see doi:10.1063/5.0004046.

    The core Hamiltonian plus the SAP potential is diagonalized and the
    resulting orbitals are occupied through :func:`get_occ` of an ``RHF``
    object.  Extra keyword arguments are accepted but not used.

    Args:
        mol : MoleBase object
            the molecule object for which the initial guess is evaluated
        sap_basis : dict
            SAP basis in internal format (python dictionary)

    Returns:
        dm0 : ndarray
            SAP initial guess density matrix
    '''
    v_sap = make_sap(mol, sap_basis=sap_basis)
    guess_hamiltonian = get_hcore(mol) + v_sap
    mo_energy, mo_coeff = eig(guess_hamiltonian, get_ovlp(mol))

    mo_occ = get_occ(RHF(mol), mo_energy, mo_coeff)
    return make_rdm1(mo_coeff, mo_occ)

def get_init_guess(mol, key='minao', **kwargs):
    '''Generate density matrix for initial guess

    Delegates to the ``get_init_guess`` method of an ``RHF`` object created
    for ``mol``; extra keyword arguments are passed through.

    Kwargs:
        key : str
            One of 'minao', 'atom', 'huckel', 'hcore', '1e', 'sap', 'chkfile'.
    '''
    mf = RHF(mol)
    return mf.get_init_guess(mol, key, **kwargs)


# eigenvalue of d is 1
def level_shift(s, d, f, factor):
    r'''Apply level shift :math:`\Delta` to virtual orbitals

    .. math::
       :nowrap:

       \begin{align}
         FC &= SCE \\
         F &= F + SC \Lambda C^\dagger S \\
         \Lambda_{ij} &=
         \begin{cases}
            \delta_{ij}\Delta & i \in \text{virtual} \\
            0 & \text{otherwise}
         \end{cases}
       \end{align}

    Args:
        s : overlap matrix
        d : density matrix
        f : Fock matrix
        factor : the shift :math:`\Delta`

    Returns:
        New Fock matrix, 2D ndarray
    '''
    # S - S.D.S is the quantity scaled by the shift and added to f
    sds = lib.dot(lib.dot(s, d), s)
    virtual_part = s - sds
    return f + virtual_part * factor


def damping(f, f_prev, factor):
    '''Linear mix of the current and the previous quantity.

    Returns:
        ``f*(1-factor) + f_prev*factor``
    '''
    weight_current = 1-factor
    return f*weight_current + f_prev*factor

def make_sap(mol, sap_basis):
    '''Superposition of atomic potentials (SAP) potential matrix

    For every atom a fake molecule carrying the fitted charge distribution is
    appended to a copy of ``mol``; the potential matrix is minus the sum of
    the 3-center integrals ``int3c2e`` over all appended shells.

    A warning is logged (the calculation continues) when the sum of all SAP
    coefficients plus ``mol.nelectron`` deviates from zero by more than 1e-6.

    Args:
        mol : MoleBase object
            molecule for which SAP is computed
        sap_basis : dict
            SAP basis.  ``sap_basis[symbol]`` is indexed as a 2D array whose
            column 0 holds the exponents and column 1 the coefficients.

    Returns:
        Vsap : ndarray
            SAP potential matrix
    '''
    from pyscf.gto.mole import fakemol_for_cgtf_charge

    atom_coords = numpy.asarray([coord[1] for coord in mol._atom], dtype=float)
    atoms = [coord[0] for coord in mol._atom]

    # charge sumcheck
    Z_eff = sum(numpy.sum(sap_basis[a][:,1]) for a in atoms)
    charge_mismatch = numpy.abs(Z_eff + mol.nelectron)
    if charge_mismatch > 1e-6:
        logger.warn(
            mol,
            '\n'.join(['SAP basis coefficients must be equal or close '
                       f'to total electronic charge: {Z_eff} !≃ {mol.nelectron}',
                       f'Check fails with value {charge_mismatch}']))

    cmol = mol.copy()
    # Number of shells of the original molecule; the shells appended below
    # occupy the range [nbas, cmol.nbas).
    nbas = cmol.nbas

    for i, atom in enumerate(atoms):
        expnt = sap_basis[atom][:,0]
        coeff = sap_basis[atom][:,1]
        nucleon_fakemol = fakemol_for_cgtf_charge(
            numpy.asarray([atom_coords[i]], dtype=float),
            expnt,
            coeff)

        cmol += nucleon_fakemol

    shls_slice = (0, nbas, 0, nbas, nbas, cmol.nbas)
    int3c2e = cmol.intor('int3c2e', comp=1, shls_slice=shls_slice)
    # Sum over the auxiliary (third) index
    return -numpy.einsum('pqk->pq', int3c2e)

# full density matrix for RHF
def make_rdm1(mo_coeff, mo_occ, **kwargs):
    '''One-particle density matrix in AO representation

    Only orbitals with positive occupancy contribute.  The returned array is
    tagged with the attributes ``mo_coeff`` and ``mo_occ``.

    Args:
        mo_coeff : 2D ndarray
            Orbital coefficients. Each column is one orbital.
        mo_occ : 1D ndarray
            Occupancy
    Returns:
        One-particle density matrix, 2D ndarray
    '''
    occupied = mo_occ > 0
    c_occ = mo_coeff[:,occupied]
    # Scale every occupied column by its occupancy, then contract
    weighted = c_occ * mo_occ[occupied]
    dm = weighted.dot(c_occ.conj().T)
    return lib.tag_array(dm, mo_coeff=mo_coeff, mo_occ=mo_occ)

def make_rdm2(mo_coeff, mo_occ, **kwargs):
    '''Two-particle density matrix in AO representation

    NOTE the indices of the two-particle density matrix are ordered as

    dm2[p,q,r,s] = <q^+ s^+ r p>

    to make the density matrix consistent with the density matrix obtained
    from post-HF methods.  With this convention the HF energy can be computed

    E = einsum('pq,qp', hcore, 1pdm) + einsum('pqrs,pqrs', eri, 2pdm) / 2

    where h1[p,q] = <p|h|q> and eri[p,q,r,s] = (pq|rs).

    Args:
        mo_coeff : 2D ndarray
            Orbital coefficients. Each column is one orbital.
        mo_occ : 1D ndarray
            Occupancy
    Returns:
        Two-particle density matrix, 4D ndarray
    '''
    dm1 = make_rdm1(mo_coeff, mo_occ, **kwargs)
    direct = numpy.einsum('ij,kl->ijkl', dm1, dm1)
    exchange = numpy.einsum('ij,kl->iklj', dm1, dm1)
    return direct - exchange/2

################################################
# for general DM
# hermi = 0 : arbitrary
# hermi = 1 : hermitian
# hermi = 2 : anti-hermitian
################################################
def dot_eri_dm(eri, dm, hermi=0, with_j=True, with_k=True):
    '''Compute J, K matrices in terms of the given 2-electron integrals and
    density matrix:

    J ~ numpy.einsum('pqrs,qp->rs', eri, dm)
    K ~ numpy.einsum('pqrs,qr->ps', eri, dm)

    Complex integrals, or integrals holding all N^4 elements, are contracted
    with numpy.einsum.  Any other integral array is handed to
    ``_vhf.incore``.

    Args:
        eri : ndarray
            8-fold or 4-fold ERIs or complex integral array with N^4 elements
            (N is the number of orbitals)
        dm : ndarray or list of ndarrays
            A density matrix or a list of density matrices

    Kwargs:
        hermi : int
            Whether J, K matrix is hermitian

            | 0 : no hermitian or symmetric
            | 1 : hermitian
            | 2 : anti-hermitian

        with_j : bool
            Whether to compute J.  In the numpy.einsum code path None is
            returned in place of J when it is False.
        with_k : bool
            Whether to compute K.  In the numpy.einsum code path None is
            returned in place of K when it is False.

    Returns:
        Depending on the given dm, the function returns one J and one K matrix,
        or a list of J matrices and a list of K matrices, corresponding to the
        input density matrices.

    Examples:

    >>> from pyscf import gto, scf
    >>> from pyscf.scf import _vhf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> eri = _vhf.int2e_sph(mol._atm, mol._bas, mol._env)
    >>> dms = numpy.random.random((3,mol.nao_nr(),mol.nao_nr()))
    >>> j, k = scf.hf.dot_eri_dm(eri, dms, hermi=0)
    >>> print(j.shape)
    (3, 2, 2)
    '''
    dm = numpy.asarray(dm)
    nao = dm.shape[-1]
    use_einsum = eri.dtype == numpy.complex128 or eri.size == nao**4

    if not use_einsum:
        # Real and imaginary part of the density matrix are contracted
        # separately; hermi is set to 0 for the imaginary part.
        vj, vk = _vhf.incore(eri, dm.real, hermi, with_j, with_k)
        if dm.dtype == numpy.complex128:
            imag_jk = _vhf.incore(eri, dm.imag, 0, with_j, with_k)
            if with_j:
                vj = vj + imag_jk[0] * 1j
            if with_k:
                vk = vk + imag_jk[1] * 1j
        return vj, vk

    eri = eri.reshape((nao,)*4)
    # Treat a single density matrix as a stack of one
    dms = dm.reshape(-1,nao,nao)
    vj = None
    vk = None
    if with_j:
        vj = numpy.einsum('ijkl,xji->xkl', eri, dms).reshape(dm.shape)
    if with_k:
        vk = numpy.einsum('ijkl,xjk->xil', eri, dms).reshape(dm.shape)
    return vj, vk


def get_jk(mol, dm, hermi=1, vhfopt=None, with_j=True, with_k=True, omega=None):
    '''Integral-direct evaluation of the J and K matrices for all input
    density matrices

    Args:
        mol : an instance of :class:`Mole`

        dm : ndarray or list of ndarrays
            A density matrix or a list of density matrices

    Kwargs:
        hermi : int
            Whether J, K matrix is hermitian.  For complex density matrices
            this flag is overridden with 0.

            | 0 : not hermitian and not symmetric
            | 1 : hermitian or symmetric
            | 2 : anti-hermitian

        vhfopt :
            A class which holds precomputed quantities to optimize the
            computation of J, K matrices

        with_j : boolean
            Whether to compute J matrices

        with_k : boolean
            Whether to compute K matrices

        omega : float
            Parameter of range-separated Coulomb operator.
            When omega is 0 (or None), integrals are computed with the full-range Coulomb potential.
            When it is larger than zero, integrals are evaluated with the long-range
            Coulomb potential erf( omega * r12 ) / r12. When omega is smaller
            than 0, short-range Coulomb potential erfc( omega * r12 ) / r12 is applied.

    Returns:
        Depending on the given dm, the function returns one J and one K matrix,
        or a list of J matrices and a list of K matrices, corresponding to the
        input density matrices.

    Examples:

    >>> from pyscf import gto, scf
    >>> from pyscf.scf import _vhf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> dms = numpy.random.random((3,mol.nao_nr(),mol.nao_nr()))
    >>> j, k = scf.hf.get_jk(mol, dms, hermi=0)
    >>> print(j.shape)
    (3, 2, 2)
    '''
    dm = numpy.asarray(dm, order='C')
    out_shape = dm.shape
    nao = out_shape[-1]
    complex_dm = dm.dtype == numpy.complex128

    if complex_dm:
        # The driver is fed real arrays: stack the real parts on top of the
        # imaginary parts and drop the symmetry flag.
        dm = numpy.vstack((dm.real, dm.imag)).reshape(-1,nao,nao)
        hermi = 0

    with mol.with_range_coulomb(omega):
        vj, vk = _vhf.direct(dm, mol._atm, mol._bas, mol._env,
                             vhfopt, hermi, mol.cart, with_j, with_k)

    def restore_layout(mat):
        # Undo the real/imag stacking (if any) and recover the input shape.
        if complex_dm:
            halves = mat.reshape((2,) + out_shape)
            return halves[0] + halves[1] * 1j
        return mat.reshape(out_shape)

    if with_j:
        vj = restore_layout(vj)
    if with_k:
        vk = restore_layout(vk)
    return vj, vk


def get_veff(mol, dm, dm_last=None, vhf_last=None, hermi=1, vhfopt=None):
    '''Hartree-Fock potential matrix for the given density matrix

    Args:
        mol : an instance of :class:`Mole`

        dm : ndarray or list of ndarrays
            A density matrix or a list of density matrices

    Kwargs:
        dm_last : ndarray or a list of ndarrays or None
            The density matrix baseline.  If not None, J and K are built for
            the difference ``dm - dm_last`` only and the result is added to
            ``vhf_last``.
        vhf_last : ndarray or a list of ndarrays
            The reference HF potential matrix.  Only used when dm_last is
            given.
        hermi : int
            Whether J, K matrix is hermitian

            | 0 : no hermitian or symmetric
            | 1 : hermitian
            | 2 : anti-hermitian

        vhfopt :
            A class which holds precomputed quantities to optimize the
            computation of J, K matrices

    Returns:
        matrix Vhf = J - K/2, where J and K are evaluated with the density
        matrix as given.  Vhf can be a list matrices, corresponding to the
        input density matrices.

    Examples:

    >>> import numpy
    >>> from pyscf import gto, scf
    >>> from pyscf.scf import _vhf
    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1')
    >>> dm0 = numpy.random.random((mol.nao_nr(),mol.nao_nr()))
    >>> vhf0 = scf.hf.get_veff(mol, dm0, hermi=0)
    >>> dm1 = numpy.random.random((mol.nao_nr(),mol.nao_nr()))
    >>> vhf1 = scf.hf.get_veff(mol, dm1, hermi=0)
    >>> vhf2 = scf.hf.get_veff(mol, dm1, dm_last=dm0, vhf_last=vhf0, hermi=0)
    >>> numpy.allclose(vhf1, vhf2)
    True
    '''
    incremental = dm_last is not None

    dm_in = numpy.asarray(dm)
    if incremental:
        dm_in = dm_in - numpy.asarray(dm_last)

    vj, vk = get_jk(mol, dm_in, hermi, vhfopt)
    vhf = vj - vk * .5
    if incremental:
        vhf = vhf + numpy.asarray(vhf_last)
    return vhf

def get_fock(mf, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1, diis=None,
             diis_start_cycle=None, level_shift_factor=None, damp_factor=None,
             fock_last=None):
    '''F = h^{core} + V^{HF}

    When called with the default ``cycle`` and ``diis`` (i.e. outside of the
    SCF iteration) the plain sum is returned.  Otherwise damping, DIIS and
    level shift are applied, in this order, whenever their respective
    conditions are met.

    Kwargs:
        h1e : 2D ndarray
            Core hamiltonian.  Obtained from mf.get_hcore() if not given.
        s1e : 2D ndarray
            Overlap matrix, for DIIS and level shift.  Obtained from
            mf.get_ovlp() if not given.
        vhf : 2D ndarray
            HF potential matrix.  Obtained from mf.get_veff(mf.mol, dm) if
            not given.
        dm : 2D ndarray
            Density matrix, for DIIS and level shift.  Obtained from
            mf.make_rdm1() if not given.
        cycle : int
            The present SCF iteration step
        diis : an object of :attr:`SCF.DIIS` class
            DIIS object to hold intermediate Fock and error vectors
        diis_start_cycle : int
            The step to start DIIS.  Defaults to mf.diis_start_cycle.
        level_shift_factor : float or int
            Level shift (in AU) for virtual space.  Defaults to mf.level_shift.
        damp_factor : float or int
            Damping factor.  Defaults to mf.damp.
        fock_last : 2D ndarray
            Fock matrix of the previous step, used by damping and handed to
            the DIIS object.
    '''
    if h1e is None:
        h1e = mf.get_hcore()
    if vhf is None:
        vhf = mf.get_veff(mf.mol, dm)
    fock = h1e + vhf

    outside_scf_loop = cycle < 0 and diis is None
    if outside_scf_loop:
        return fock

    # Fill in the convergence-helper settings from the SCF object
    if diis_start_cycle is None:
        diis_start_cycle = mf.diis_start_cycle
    if level_shift_factor is None:
        level_shift_factor = mf.level_shift
    if damp_factor is None:
        damp_factor = mf.damp
    if s1e is None:
        s1e = mf.get_ovlp()
    if dm is None:
        dm = mf.make_rdm1()

    # Damping is restricted to the cycles preceding the DIIS start
    before_diis = 0 <= cycle < diis_start_cycle-1
    if before_diis and abs(damp_factor) > 1e-4 and fock_last is not None:
        fock = damping(fock, fock_last, damp_factor)

    if diis is not None and cycle >= diis_start_cycle:
        fock = diis.update(s1e, dm, fock, mf, h1e, vhf, f_prev=fock_last)

    if abs(level_shift_factor) > 1e-4:
        fock = level_shift(s1e, dm*.5, fock, level_shift_factor)
    return fock

def get_occ(mf, mo_energy=None, mo_coeff=None):
    '''Label the occupancies for each orbital

    The ``mf.mol.nelectron // 2`` orbitals of lowest energy are assigned an
    occupancy of 2, all the others 0.

    Kwargs:
        mo_energy : 1D ndarray
            Orbital energies.  Defaults to mf.mo_energy.

        mo_coeff : 2D ndarray
            Orbital coefficients (not used by this function)

    Returns:
        Array of occupancies, same shape and dtype as mo_energy.

    Examples:

    >>> from pyscf import gto, scf
    >>> mol = gto.M(atom='H 0 0 0; F 0 0 1.1')
    >>> mf = scf.hf.SCF(mol)
    >>> energy = numpy.array([-10., -1., 1, -2., 0, -3])
    >>> mf.get_occ(energy)
    array([2, 2, 0, 2, 2, 2])
    '''
    if mo_energy is None: mo_energy = mf.mo_energy
    e_idx = numpy.argsort(mo_energy)
    e_sort = mo_energy[e_idx]
    nmo = mo_energy.size
    mo_occ = numpy.zeros_like(mo_energy)
    nocc = mf.mol.nelectron // 2
    mo_occ[e_idx[:nocc]] = 2
    # HOMO/LUMO are only defined if there is at least one occupied and one
    # virtual orbital.  Without the lower bound, nocc == 0 would make
    # e_sort[nocc-1] wrap around to the highest orbital energy.
    if mf.verbose >= logger.INFO and 0 < nocc < nmo:
        if e_sort[nocc-1]+1e-3 > e_sort[nocc]:
            logger.warn(mf, 'HOMO %.15g == LUMO %.15g',
                        e_sort[nocc-1], e_sort[nocc])
        else:
            logger.info(mf, '  HOMO = %.15g  LUMO = %.15g',
                        e_sort[nocc-1], e_sort[nocc])

    if mf.verbose >= logger.DEBUG:
        # Temporarily lift the summarization threshold so that all orbital
        # energies are printed; the context manager restores whatever print
        # options were active before instead of forcing threshold=1000.
        with numpy.printoptions(threshold=nmo):
            logger.debug(mf, '  mo_energy =\n%s', mo_energy)
    return mo_occ

def get_grad(mo_coeff, mo_occ, fock_ao):
    '''RHF orbital gradients

    Args:
        mo_coeff : 2D ndarray
            Orbital coefficients
        mo_occ : 1D ndarray
            Orbital occupancy.  Orbitals with occupancy > 0 are treated as
            occupied, all the others as virtual.
        fock_ao : 2D ndarray
            Fock matrix in AO representation

    Returns:
        Gradients in MO representation: twice the virtual-occupied block of
        the Fock matrix, flattened to a vector of num_vir*num_occ elements.
    '''
    is_occupied = mo_occ > 0
    orb_occ = mo_coeff[:,is_occupied]
    orb_vir = mo_coeff[:,~is_occupied]
    fock_times_occ = fock_ao.dot(orb_occ)
    grad_vo = orb_vir.conj().T.dot(fock_times_occ) * 2
    return grad_vo.ravel()


def analyze(mf, verbose=logger.DEBUG, with_meta_lowdin=WITH_META_LOWDIN,
            origin=None, **kwargs):
    '''Analyze the given SCF object:  print orbital energies, occupancies;
    print orbital coefficients; Mulliken population analysis; Dipole moment.

    Args:
        mf : SCF object providing mo_energy, mo_occ and mo_coeff

    Kwargs:
        verbose : int or instance of :class:`lib.logger.Logger`
            Print level (or the logger to write to)
        with_meta_lowdin : bool
            If True, the orbital coefficients are expanded on meta-Lowdin
            orthogonalized AOs and the population analysis is carried out by
            mf.mulliken_meta; otherwise the plain AOs and mf.mulliken_pop
            are used.
        origin : optional
            Forwarded to mf.dip_moment
        kwargs :
            Forwarded to dump_mat.dump_rec when printing the coefficients

    Returns:
        A tuple: (result of the population analysis, dipole moment)
    '''
    from pyscf.lo import orth
    from pyscf.tools import dump_mat
    mo_energy = mf.mo_energy
    mo_occ = mf.mo_occ
    mo_coeff = mf.mo_coeff
    log = logger.new_logger(mf, verbose)

    if log.verbose >= logger.NOTE:
        mf.dump_scf_summary(log)
        log.note('**** MO energy ****')
        for i,c in enumerate(mo_occ):
            log.note('MO #%-3d energy= %-18.15g occ= %g', i+MO_BASE,
                     mo_energy[i], c)

    ovlp_ao = mf.get_ovlp()
    # Test the level of the resolved logger, not the raw argument: `verbose`
    # may be a Logger object (or None), which cannot be compared to an int.
    if log.verbose >= logger.DEBUG:
        label = mf.mol.ao_labels()
        if with_meta_lowdin:
            log.debug(' ** MO coefficients (expansion on meta-Lowdin AOs) **')
            orth_coeff = orth.orth_ao(mf.mol, 'meta_lowdin', s=ovlp_ao)
            c = reduce(numpy.dot, (orth_coeff.conj().T, ovlp_ao, mo_coeff))
        else:
            log.debug(' ** MO coefficients (expansion on AOs) **')
            c = mo_coeff
        dump_mat.dump_rec(mf.stdout, c, label, start=MO_BASE, **kwargs)
    dm = mf.make_rdm1(mo_coeff, mo_occ)
    if with_meta_lowdin:
        return (mf.mulliken_meta(mf.mol, dm, s=ovlp_ao, verbose=log),
                mf.dip_moment(mf.mol, dm, origin=origin, verbose=log))
    else:
        return (mf.mulliken_pop(mf.mol, dm, s=ovlp_ao, verbose=log),
                mf.dip_moment(mf.mol, dm, origin=origin, verbose=log))

def dump_scf_summary(mf, verbose=logger.DEBUG):
    '''Log the total energy and the energy components recorded in
    mf.scf_summary.  Nothing is printed if mf.scf_summary is empty.

    Only the components that are present in mf.scf_summary are reported.  If
    the object carries a non-zero ``entropy`` attribute, the entropy, the
    zero point energy and the free energy are reported as well.
    '''
    summary = mf.scf_summary
    if not summary:
        return

    log = logger.new_logger(mf, verbose)
    log.info('**** SCF Summaries ****')
    log.info('Total Energy =                    %24.15f', mf.e_tot)

    # (format, key in scf_summary), in the order they are printed
    components = (
        ('Nuclear Repulsion Energy =        %24.15f', 'nuc'),
        ('One-electron Energy =             %24.15f', 'e1'),
        ('Two-electron Energy =             %24.15f', 'e2'),
        ('Two-electron Coulomb Energy =     %24.15f', 'coul'),
        ('DFT Exchange-Correlation Energy = %24.15f', 'exc'),
        ('Empirical Dispersion Energy =     %24.15f', 'dispersion'),
        ('PCM Polarization Energy =         %24.15f', 'epcm'),
        ('EFP Energy =                      %24.15f', 'efp'),
    )
    for fmt, key in components:
        if key in summary:
            log.info(fmt, summary[key])

    if getattr(mf, 'entropy', None):
        log.info('(Electronic) Entropy              %24.15f', mf.entropy)
        log.info('(Electronic) Zero Point Energy    %24.15f', mf.e_zero)
        log.info('Free Energy =                     %24.15f', mf.e_free)


def mulliken_pop(mol, dm, s=None, verbose=logger.DEBUG):
    r'''Mulliken population analysis

    .. math:: M_{ij} = D_{ij} S_{ji}

    Mulliken charges

    .. math:: \delta_i = \sum_j M_{ij}

    Args:
        mol : an instance of :class:`Mole`

        dm : ndarray or 2-item list of ndarray
            Density matrix.  Anything but a single 2D array is treated as a
            pair of density matrices which are summed up.

    Kwargs:
        s : ndarray
            Overlap matrix.  Computed with get_ovlp(mol) if not given.
        verbose : int or instance of :class:`lib.logger.Logger`

    Returns:
        A list : pop, charges

        pop : nparray
            Mulliken population on each atomic orbitals
        charges : nparray
            Mulliken charges
    '''
    if s is None:
        s = get_ovlp(mol)
    log = logger.new_logger(mol, verbose)

    single_dm = isinstance(dm, numpy.ndarray) and dm.ndim == 2
    if single_dm:
        dm_total = dm
    else:  # ROHF
        dm_total = dm[0]+dm[1]
    pop = numpy.einsum('ij,ji->i', dm_total, s).real

    log.info(' ** Mulliken pop  **')
    for ao_id, ao_name in enumerate(mol.ao_labels()):
        log.info('pop of  %-14s %10.5f', ao_name, pop[ao_id])

    log.note(' ** Mulliken atomic charges  **')
    # Sum the AO populations atom by atom; the first field of an unformatted
    # AO label is used as the atom index.
    atom_pop = numpy.zeros(mol.natm)
    for ao_id, ao_fields in enumerate(mol.ao_labels(fmt=None)):
        atom_pop[ao_fields[0]] += pop[ao_id]
    chg = mol.atom_charges() - atom_pop
    for ia in range(mol.natm):
        log.note('charge of  %3d%s =   %10.5f', ia, mol.atom_symbol(ia), chg[ia])
    return pop, chg


def mulliken_meta(mol, dm, verbose=logger.DEBUG,
                  pre_orth_method=PRE_ORTH_METHOD, s=None):
    '''Mulliken population analysis, based on meta-Lowdin AOs.
    In the meta-lowdin, the AOs are grouped in three sets: core, valence and
    Rydberg, the orthogonalization are carried out within each subsets.

    The density matrix is transformed to the orthogonalized AO basis and then
    handed to :func:`mulliken_pop` together with an identity overlap matrix.

    Args:
        mol : an instance of :class:`Mole`

        dm : ndarray or 2-item list of ndarray
            Density matrix.  ROHF dm is a 2-item list of 2D array

    Kwargs:
        verbose : int or instance of :class:`lib.logger.Logger`

        pre_orth_method : str
            Pre-orthogonalization, which localized GTOs for each atom.
            To obtain the occupied and unoccupied atomic shells, there are
            three methods

            | 'ano'   : Project GTOs to ANO basis
            | 'minao' : Project GTOs to MINAO basis
            | 'scf'   : Symmetry-averaged fractional occupation atomic RHF

        s : ndarray
            Overlap matrix.  Computed with get_ovlp(mol) if not given.

    Returns:
        A list : pop, charges

        pop : nparray
            Mulliken population on each atomic orbitals
        charges : nparray
            Mulliken charges
    '''
    from pyscf.lo import orth
    if s is None:
        s = get_ovlp(mol)
    log = logger.new_logger(mol, verbose)

    orth_coeff = orth.orth_ao(mol, 'meta_lowdin', pre_orth_method, s=s)
    # C^H S maps AO quantities onto the orthogonalized AOs
    ao2orth = numpy.dot(orth_coeff.conj().T, s)

    if isinstance(dm, numpy.ndarray) and dm.ndim == 2:
        dm_total = dm
    else:  # ROHF
        dm_total = dm[0]+dm[1]
    dm_orth = reduce(numpy.dot, (ao2orth, dm_total, ao2orth.T.conj()))

    log.info(' ** Mulliken pop on meta-lowdin orthogonal AOs  **')
    unit_ovlp = numpy.eye(orth_coeff.shape[0])
    return mulliken_pop(mol, dm_orth, unit_ovlp, log)
mulliken_pop_meta_lowdin_ao = mulliken_meta


def eig(h, s):
    '''Solver for generalized eigenvalue problem

    .. math:: HC = SCE

    The sign of each eigenvector is fixed such that the real part of its
    component of largest magnitude (w.r.t. the real part) is not negative.

    Returns:
        A tuple: eigenvalues, eigenvectors (stored column-wise)
    '''
    e, c = scipy.linalg.eigh(h, s)
    columns = numpy.arange(len(e))
    pivot_rows = numpy.argmax(abs(c.real), axis=0)
    flip = c[pivot_rows,columns].real < 0
    c[:,flip] *= -1
    return e, c

def canonicalize(mf, mo_coeff, mo_occ, fock=None):
    '''Canonicalization diagonalizes the Fock matrix within the doubly
    occupied, open-shell and virtual subspaces separately (the occupancies
    are not changed).

    Args:
        mf : SCF object.  Only used to build the Fock matrix when ``fock``
            is not given.
        mo_coeff : 2D ndarray
            Orbital coefficients
        mo_occ : 1D ndarray
            Orbital occupancy.  2 marks the doubly occupied subspace, 0 the
            virtual subspace, anything else the open-shell subspace.

    Kwargs:
        fock : 2D ndarray
            Fock matrix in AO representation

    Returns:
        A tuple: orbital energies, orbital coefficients
    '''
    if fock is None:
        dm = mf.make_rdm1(mo_coeff, mo_occ)
        fock = mf.get_fock(dm=dm)

    in_core = mo_occ == 2
    in_virtual = mo_occ == 0
    in_open = ~(in_core | in_virtual)

    mo = numpy.empty_like(mo_coeff)
    mo_e = numpy.empty(mo_occ.size)
    for subspace in (in_core, in_open, in_virtual):
        if numpy.count_nonzero(subspace) == 0:
            continue
        orb = mo_coeff[:,subspace]
        fock_sub = reduce(numpy.dot, (orb.conj().T, fock, orb))
        e_sub, rotation = scipy.linalg.eigh(fock_sub)
        mo[:,subspace] = numpy.dot(orb, rotation)
        mo_e[subspace] = e_sub
    return mo_e, mo

def dip_moment(mol, dm, unit='Debye', origin=None, verbose=logger.NOTE, **kwargs):
    r''' Dipole moment calculation

    .. math::

        \mu_x = -\sum_{\mu}\sum_{\nu} P_{\mu\nu}(\nu|x|\mu) + \sum_A Q_A X_A\\
        \mu_y = -\sum_{\mu}\sum_{\nu} P_{\mu\nu}(\nu|y|\mu) + \sum_A Q_A Y_A\\
        \mu_z = -\sum_{\mu}\sum_{\nu} P_{\mu\nu}(\nu|z|\mu) + \sum_A Q_A Z_A

    where :math:`\mu_x, \mu_y, \mu_z` are the x, y and z components of dipole
    moment.  Electronic and nuclear positions are both measured from
    ``origin``.

    Args:
         mol: an instance of :class:`Mole`
         dm : a 2D ndarrays density matrices.  Anything else is treated as a
            pair of density matrices which are summed up.

    Kwargs:
         unit : str
            'Debye' (case-insensitive) converts the result to Debye; any
            other value leaves it in atomic units.
         origin : optional; length 3 list, tuple, or 1D array
            Location of the origin. By default, the point (0, 0, 0) is used.
         verbose : int or instance of :class:`lib.logger.Logger`

    Return:
        1D ndarray: the dipole moment on x, y and z component
    '''

    log = logger.new_logger(mol, verbose)

    if 'unit_symbol' in kwargs:  # pragma: no cover
        log.warn('Kwarg "unit_symbol" was deprecated. It was replaced by kwarg '
                 'unit since PySCF-1.5.')
        unit = kwargs['unit_symbol']

    single_dm = isinstance(dm, numpy.ndarray) and dm.ndim == 2
    if not single_dm:
        # UHF density matrices
        dm = dm[0] + dm[1]

    charges = mol.atom_charges()
    coords  = mol.atom_coords()

    origin = (numpy.zeros(3) if origin is None
              else numpy.asarray(origin, dtype=numpy.float64))
    assert origin.shape == (3,)

    if mol.charge != 0:
        log.warn(f"System has nonzero charge {mol.charge}; the dipole moment is origin-dependent.\n"
                 f"Location of origin: {origin}")

    # Nuclear contribution, positions relative to the chosen origin
    nucl_dip = numpy.einsum('i,ix->x', charges, coords - origin[None, :])
    # Electronic contribution from the position integrals at the same origin
    with mol.with_common_orig(origin):
        ao_dip = mol.intor_symmetric('int1e_r', comp=3)
    el_dip = numpy.einsum('xij,ji->x', ao_dip, dm).real
    mol_dip = nucl_dip - el_dip

    in_debye = unit.upper() == 'DEBYE'
    if in_debye:
        mol_dip *= nist.AU2DEBYE
        log.note('Dipole moment(X, Y, Z, Debye): %8.5f, %8.5f, %8.5f', *mol_dip)
    else:
        log.note('Dipole moment(X, Y, Z, A.U.): %8.5f, %8.5f, %8.5f', *mol_dip)
    return mol_dip

def quad_moment(mol, dm, unit='DebyeAngstrom', origin=None,
                verbose=logger.NOTE, **kwargs):
    r''' Calculates traceless quadrupole moment tensor.

    The traceless quadrupole tensor is given by

    .. math::

        Q_{ij} &= - \frac{1}{2} \sum_{\mu \nu} P_{\mu \nu}
                \left[ 3 (\nu | r_i r_j | \mu) - \delta_{ij} (\nu | r^2 | \mu) \right] \\
               &+ \frac{1}{2} \sum_A Q_A
               \left( 3 R_{iA} R_{jA} - \delta_{ij} \|\mathbf{R}_A\|^2  \right).

    If the molecule has a dipole, the quadrupole moment depends on the location
    of the origin. By default, the origin is taken to be (0, 0, 0), but it can
    be set manually via the keyword argument `origin`.

    Args:
         mol: an instance of :class:`Mole`
         dm : a 2D ndarrays density matrices.  Anything else is treated as a
            pair of density matrices which are summed up.

    Kwargs:
         unit : str
            'DebyeAngstrom', 'DebyeAng' or 'DebyeA' (case-insensitive)
            converts the result to Debye*Angstrom; any other value leaves it
            in atomic units.
         origin : optional; length 3 list, tuple, or 1D array
            Location of the origin. By default, it is (0, 0, 0).
         verbose : int or instance of :class:`lib.logger.Logger`

    Return:
        Traceless quadrupole tensor, 2D ndarray.
    '''

    log = logger.new_logger(mol, verbose)

    if 'unit_symbol' in kwargs:  # pragma: no cover
        log.warn('Kwarg "unit_symbol" was deprecated. It was replaced by kwarg '
                 'unit since PySCF-1.5.')
        unit = kwargs['unit_symbol']

    if not (isinstance(dm, numpy.ndarray) and dm.ndim == 2):
        # UHF density matrices
        dm = dm[0] + dm[1]

    charges = mol.atom_charges()
    coords  = mol.atom_coords()

    if origin is None:
        origin = numpy.zeros(3)
    else:
        origin = numpy.asarray(origin, dtype=numpy.float64)
    assert origin.shape == (3,)

    with mol.with_common_orig(origin):
        # 9 components r_i r_j, each flattened over the AO pair index
        quad_ints = mol.intor_symmetric("int1e_rr", comp=9).reshape((3, 3, -1))
    r_nuc = coords - origin[None, :]
    elec_q = (quad_ints @ dm.ravel()).real
    nuc_q = numpy.einsum("g,gx,gy->xy", charges, r_nuc, r_nuc)
    tot_q = (nuc_q - elec_q) / 2
    # 3*Q - tr(Q)*1 removes the trace
    tot_q_traceless = 3 * tot_q - numpy.eye(3) * numpy.trace(tot_q)

    if unit.upper() in ('DEBYEANGSTROM', 'DEBYEANG', 'DEBYEA'):
        tot_q_traceless *= nist.AU2DEBYE * nist.BOHR
        log.note('Traceless quadrupole moment (Debye*A):')
    else:
        log.note('Traceless quadrupole moment (AU):')

    # Print the tensor itself under the header written above; it is
    # formatted inside the context so the fixed 5-digit precision applies.
    with numpy.printoptions(precision=5, floatmode='fixed'):
        log.note('%s', tot_q_traceless)

    return tot_q_traceless

def uniq_var_indices(mo_occ):
    '''
    Boolean mask of the unique variables in the orbital-gradients (or
    orbital-rotation) matrix.

    Element (p, q) is selected if orbital q has occupancy > 0 while orbital p
    has not, or if orbital q has occupancy 2 while orbital p has not.
    '''
    occupied = mo_occ>0
    doubly_occupied = mo_occ==2
    from_occupied = ~occupied[:,None] & occupied
    from_doubly_occupied = ~doubly_occupied[:,None] & doubly_occupied
    return from_occupied | from_doubly_occupied

def pack_uniq_var(x, mo_occ):
    '''
    Pick the unique variables, as flagged by :func:`uniq_var_indices`, out of
    the full orbital-gradients (or orbital-rotation) matrix
    '''
    return x[uniq_var_indices(mo_occ)]

def unpack_uniq_var(dx, mo_occ):
    '''
    Scatter the unique variables into the full orbital-gradients (or
    orbital-rotation) matrix and anti-hermitize the result.
    '''
    unique = uniq_var_indices(mo_occ)
    norb = len(mo_occ)

    full = numpy.zeros((norb,norb), dtype=dx.dtype)
    full[unique] = dx
    return full - full.conj().T


def as_scanner(mf):
    '''Generating a scanner/solver for HF PES.

    The returned solver is a callable object.  It takes a "mol" object (or a
    geometry) as input and returns the total HF energy.

    The solver will automatically use the results of last calculation as the
    initial guess of the new calculation.  All parameters assigned in the
    SCF object (DIIS, conv_tol, max_memory etc) are automatically applied in
    the solver.

    An object which already is a scanner is returned unchanged.

    Note scanner has side effects.  It may change many underlying objects
    (_scf, with_df, with_x2c, ...) during calculation.

    Examples:

    >>> from pyscf import gto, scf
    >>> hf_scanner = scf.RHF(gto.Mole().set(verbose=0)).as_scanner()
    >>> hf_scanner(gto.M(atom='H 0 0 0; F 0 0 1.1'))
    -98.552190448277955
    >>> hf_scanner(gto.M(atom='H 0 0 0; F 0 0 1.5'))
    -98.414750424294368
    '''
    if isinstance(mf, lib.SinglePointScanner):
        return mf

    mf_cls = mf.__class__
    logger.info(mf, 'Create scanner for %s', mf_cls)
    # The scanner class mixes SCF_Scanner into the class of the given object
    scanner_name = mf_cls.__name__ + SCF_Scanner.__name_mixin__
    scanner_bases = (SCF_Scanner, mf_cls)
    return lib.set_class(SCF_Scanner(mf), scanner_bases, scanner_name)

class SCF_Scanner(lib.SinglePointScanner):
    '''Callable wrapper which reruns an SCF calculation for a new molecule
    or geometry and returns the total energy.'''

    def __init__(self, mf_obj):
        # Take over all the attributes of the wrapped SCF object
        self.__dict__.update(mf_obj.__dict__)
        # AO offsets of the molecule handled last; compared against the next
        # molecule to decide whether the previous density can be reused.
        self._last_mol_fp = mf_obj.mol.ao_loc

    def __call__(self, mol_or_geom, **kwargs):
        '''Run the SCF calculation for the given Mole object or geometry.

        A user-supplied ``dm0`` keyword is taken as the initial guess.  The
        remaining keywords are forwarded to ``self.kernel``.

        Returns:
            The value returned by ``self.kernel``
        '''
        if not isinstance(mol_or_geom, gto.MoleBase):
            mol = self.mol.set_geom_(mol_or_geom, inplace=False)
        else:
            mol = mol_or_geom

        # Cleanup intermediates associated to the previous mol object
        self.reset(mol)

        if 'dm0' in kwargs:
            dm0 = kwargs.pop('dm0')
        else:
            dm0 = None
            if self.mo_coeff is not None:
                # dm0 form last calculation may not be used in the current
                # calculation if a completely different system is given.
                # Obviously, the systems are very different if the number of
                # basis functions are different.
                # TODO: A robust check should include more comparison on
                # various attributes between current `mol` and the `mol` in
                # last calculation.
                same_ao_layout = numpy.array_equal(self._last_mol_fp, mol.ao_loc)
                if same_ao_layout:
                    dm0 = self.make_rdm1()
                elif self.chkfile and h5py.is_hdf5(self.chkfile):
                    dm0 = self.from_chk(self.chkfile)

        self.mo_coeff = None  # To avoid last mo_coeff being used by SOSCF
        e_tot = self.kernel(dm0=dm0, **kwargs)
        self._last_mol_fp = mol.ao_loc
        return e_tot


class SCF(lib.StreamObject):
    '''SCF base class.   non-relativistic RHF.

    Attributes:
        verbose : int
            Print level.  Default value equals to :class:`Mole.verbose`
        max_memory : float or int
            Allowed memory in MB.  Default equals to :class:`Mole.max_memory`
        chkfile : str
            checkpoint file to save MOs, orbital energies etc.  Writing to
            chkfile can be disabled if this attribute is set to None or False.
        conv_tol : float
            converge threshold.  Default is 1e-9
        conv_tol_grad : float
            gradients converge threshold.  Default is sqrt(conv_tol)
        max_cycle : int
            max number of iterations.  If max_cycle <= 0, SCF iteration will
            be skipped and the kernel function will compute only the total
            energy based on the initial guess. Default value is 50.
        init_guess : str
            initial guess method.  It can be one of 'minao', 'atom', 'huckel', 'hcore', '1e', 'sap', 'chkfile'.
            Default is 'minao'
        sap_basis : str or dict
            basis for SAP initial guess, either filename or path as str or
            internal format dictionary.
        DIIS : DIIS class
            The class to generate diis object.  It can be one of
            diis.SCF_DIIS, diis.ADIIS, diis.EDIIS.
        diis : boolean or object of DIIS class defined in :mod:`scf.diis`.
            Default is the object associated to the attribute :attr:`self.DIIS`.
            Set it to None/False to turn off DIIS.
            Note if this attribute is initialized as a DIIS object, the SCF driver
            will use this object in the iteration. The DIIS information (vector
            basis and error vector) will be held inside this object. When kernel
            function is called again, the old states (vector basis and error
            vector) will be reused.
        diis_space : int
            DIIS space size.  By default, 8 Fock matrices and errors vector are stored.
        diis_damp : float
            DIIS damping factor.  Default is 0.
        diis_start_cycle : int
            The step to start DIIS.  Default is 1.
        diis_file: 'str'
            File to store DIIS vectors and error vectors.
        level_shift : float or int
            Level shift (in AU) for virtual space.  Default is 0.
        direct_scf : bool
            Direct SCF is used by default.
        direct_scf_tol : float
            Direct SCF cutoff threshold.  Default is 1e-13.
        callback : function(envs_dict) => None
            callback function takes one dict as the argument which is
            generated by the builtin function :func:`locals`, so that the
            callback function can access all local variables in the current
            environment.
        conv_check : bool
            An extra cycle to check convergence after SCF iterations.
        check_convergence : function(envs) => bool
            A hook for overloading convergence criteria in SCF iterations.

    Saved results:

        converged : bool
            Whether the SCF iteration converged
        e_tot : float
            Total HF energy (electronic energy plus nuclear repulsion)
        mo_energy :
            Orbital energies
        mo_occ
            Orbital occupancy
        mo_coeff
            Orbital coefficients
        cycles : int
            The number of iteration cycles performed

    Examples:

    >>> mol = gto.M(atom='H 0 0 0; H 0 0 1.1', basis='cc-pvdz')
    >>> mf = scf.hf.SCF(mol)
    >>> mf.verbose = 0
    >>> mf.level_shift = .4
    >>> mf.scf()
    -1.0811707843775884
    '''
    conv_tol = getattr(__config__, 'scf_hf_SCF_conv_tol', 1e-9)
    conv_tol_grad = getattr(__config__, 'scf_hf_SCF_conv_tol_grad', None)
    conv_tol_cpscf = getattr(__config__, 'scf_hf_SCF_conv_tol_cpscf', 1e-8)
    max_cycle = getattr(__config__, 'scf_hf_SCF_max_cycle', 50)
    init_guess = getattr(__config__, 'scf_hf_SCF_init_guess', 'minao')
    sap_basis = 'sapgrasplarge' # Basis for SAP initial guess
    disp = None  # for DFT-D3 and DFT-D4

    # To avoid diis pollution from previous run, self.diis should not be
    # initialized as DIIS instance here
    DIIS = diis.SCF_DIIS
    diis = getattr(__config__, 'scf_hf_SCF_diis', True)
    diis_space = getattr(__config__, 'scf_hf_SCF_diis_space', 8)
    diis_damp = getattr(__config__, 'scf_hf_SCF_diis_damp', 0)
    # need > 0 if initial DM is numpy.zeros array
    diis_start_cycle = getattr(__config__, 'scf_hf_SCF_diis_start_cycle', 1)
    diis_file = None
    diis_space_rollback = 0

    damp = getattr(__config__, 'scf_hf_SCF_damp', 0)
    level_shift = getattr(__config__, 'scf_hf_SCF_level_shift', 0)
    direct_scf = getattr(__config__, 'scf_hf_SCF_direct_scf', True)
    direct_scf_tol = getattr(__config__, 'scf_hf_SCF_direct_scf_tol', 1e-13)
    conv_check = getattr(__config__, 'scf_hf_SCF_conv_check', True)

    callback = None

    _keys = {
        'conv_tol', 'conv_tol_grad', 'conv_tol_cpscf', 'max_cycle', 'init_guess',
        'sap_basis', 'DIIS', 'diis', 'diis_space', 'diis_damp', 'diis_start_cycle',
        'diis_file', 'diis_space_rollback', 'damp', 'level_shift',
        'direct_scf', 'direct_scf_tol', 'conv_check', 'callback',
        'mol', 'chkfile', 'mo_energy', 'mo_coeff', 'mo_occ',
        'e_tot', 'converged', 'cycles', 'scf_summary', 'opt',
        'disp', 'disp_with_3body',
    }

    def __init__(self, mol):
        '''Attach the SCF object to ``mol`` and reset all result attributes.

        A molecule that has not been built yet is built here, after writing
        a warning to stderr.
        '''
        if not mol._built:
            warning = ('Warning: %s must be initialized before calling SCF.\n'
                       'Initialize %s in %s\n' % (mol, mol, self))
            sys.stderr.write(warning)
            mol.build()

        # Settings taken over from the molecule
        self.mol = mol
        self.verbose = mol.verbose
        self.max_memory = mol.max_memory
        self.stdout = mol.stdout

        if not MUTE_CHKFILE:
            # the chkfile will be removed automatically, to save the chkfile, assign a
            # filename to self.chkfile
            self._chkfile = tempfile.NamedTemporaryFile(dir=lib.param.TMPDIR)
            self.chkfile = self._chkfile.name
        else:
            # If chkfile is muted, SCF intermediates will not be dumped anywhere.
            self.chkfile = None

        ##################################################
        # don't modify the following attributes, they are not input options
        self.mo_energy = None
        self.mo_coeff = None
        self.mo_occ = None
        self.e_tot = 0
        self.converged = False
        self.cycles = 0
        self.scf_summary = {}

        self._opt = {None: None}
        self._eri = None # Note: self._eri requires large amount of memory

    # Pickle hooks generated by lib.generate_pickle_methods.  The attributes
    # named in `excludes` (checkpoint file handles, cached integrals and
    # optimizers, the user callback) are presumably dropped from the pickled
    # state -- confirm against lib.generate_pickle_methods.
    __getstate__, __setstate__ = lib.generate_pickle_methods(
            excludes=('chkfile', '_chkfile', '_opt', '_eri', 'callback'))

    def __getattr__(self, key):
        '''Accessing methods post-HF methods or mean-field properties

        Python calls this hook only after the regular attribute lookup has
        failed.
        '''
        # Import all available modules, then retry accessing the attribute
        # (the import is executed for its side effects only)
        from pyscf import __all__  # noqa
        # Regular lookup once more; it raises AttributeError if the attribute
        # is still not available.
        return object.__getattribute__(self, key)

    def check_sanity(self):
        '''Warn about an ill-conditioned overlap matrix, then run the sanity
        checks of the parent class.

        The warning is issued when the largest condition number scaled by
        1e-17 exceeds ``self.conv_tol``.
        '''
        cond = lib.cond(self.get_ovlp())
        logger.debug(self, 'cond(S) = %s', cond)
        worst_cond = numpy.max(cond)
        if worst_cond*1e-17 > self.conv_tol:
            logger.warn(self, 'Singularity detected in overlap matrix (condition number = %4.3g). '
                        'SCF may be inaccurate and hard to converge.', worst_cond)
        return super().check_sanity()

    def build(self, mol=None):
        '''Run the sanity checks if the print level is at least WARN.

        Returns:
            self
        '''
        if mol is None:
            mol = self.mol
        sanity_check_wanted = self.verbose >= logger.WARN
        if sanity_check_wanted:
            self.check_sanity()
        return self

    @property
    def opt(self):
        '''The entry of the ``self._opt`` dict stored under the key None.'''
        return self._opt[None]
    @opt.setter
    def opt(self, x):
        # Only the entry under the key None is replaced; other entries of
        # self._opt are left untouched.
        self._opt[None] = x

    def dump_flags(self, verbose=None):
        '''Log the settings of the SCF object at INFO level.  Nothing is
        printed if the print level is lower than INFO.

        Returns:
            self
        '''
        log = logger.new_logger(self, verbose)
        if log.verbose < logger.INFO:
            return self

        info = log.info
        info('\n')
        info('******** %s ********', self.__class__)
        info('method = %s', self.__class__.__name__)
        info('initial guess = %s', self.init_guess)
        info('damping factor = %g', self.damp)
        info('level_shift factor = %s', self.level_shift)

        diis_setting = self.diis
        if isinstance(diis_setting, lib.diis.DIIS):
            # A ready-made DIIS object: report its own settings
            info('DIIS = %s', diis_setting)
            info('diis_start_cycle = %d', self.diis_start_cycle)
            info('diis_space = %d', diis_setting.space)
            if getattr(diis_setting, 'damp', None):
                info('diis_damp = %g', diis_setting.damp)
        elif diis_setting:
            # DIIS switched on: report the class and the SCF-level settings
            info('DIIS = %s', self.DIIS)
            info('diis_start_cycle = %d', self.diis_start_cycle)
            info('diis_space = %d', self.diis_space)
            info('diis_damp = %g', self.diis_damp)
        else:
            info('DIIS disabled')

        info('SCF conv_tol = %g', self.conv_tol)
        info('SCF conv_tol_grad = %s', self.conv_tol_grad)
        info('SCF max_cycles = %d', self.max_cycle)
        info('direct_scf = %s', self.direct_scf)
        if self.direct_scf:
            info('direct_scf_tol = %g', self.direct_scf_tol)
        if self.chkfile:
            info('chkfile to save SCF result = %s', self.chkfile)
        info('max_memory %d MB (current use %d MB)',
             self.max_memory, lib.current_memory()[0])
        return self


    def _eigh(self, h, s):
        '''Low-level eigensolver hook: delegates to the module-level
        function :func:`eig`.'''
        return eig(h, s)

    # The decorator is handed the docstring of the module-level function
    # eig, which is why no docstring is written in the method body.
    @lib.with_doc(eig.__doc__)
    def eig(self, h, s):
        # An intermediate call to self._eigh so that the modification to eig function
        # can be applied on different level.  Different SCF modules like RHF/UHF
        # redefine only the eig solver and leave the other modifications (like removing
        # linear dependence, sorting eigenvalues) to low level ._eigh
        return self._eigh(h, s)

    def get_hcore(self, mol=None):
        '''Core Hamiltonian from the module-level ``get_hcore``.

        ``mol`` defaults to ``self.mol``.
        '''
        target = self.mol if mol is None else mol
        return get_hcore(target)

    def get_ovlp(self, mol=None):
        '''Overlap matrix from the module-level ``get_ovlp``.

        ``mol`` defaults to ``self.mol``.
        '''
        target = self.mol if mol is None else mol
        return get_ovlp(target)

    # Bind the module-level get_fock and get_occ functions as methods, so the
    # instance is passed as their first argument.
    get_fock = get_fock
    get_occ = get_occ

    @lib.with_doc(get_grad.__doc__)
    def get_grad(self, mo_coeff, mo_occ, fock=None):
        # When no Fock matrix is supplied, rebuild it as hcore + veff from the
        # density matrix of the given orbitals and occupations.
        if fock is not None:
            return get_grad(mo_coeff, mo_occ, fock)

        dm1 = self.make_rdm1(mo_coeff, mo_occ)
        hcore = self.get_hcore(self.mol)
        veff = self.get_veff(self.mol, dm1)
        return get_grad(mo_coeff, mo_occ, hcore + veff)

    def dump_chk(self, envs_or_file):
        '''Save the SCF results to a chkfile.

        Args:
            envs_or_file:
                If a str, it is taken as the path of the file to write and the
                results stored on the object (e_tot, mo_energy, mo_coeff,
                mo_occ) are saved there.
                Otherwise it is treated as a dict (e.g. created by locals())
                holding 'e_tot', 'mo_energy', 'mo_coeff' and 'mo_occ', which
                are written to ``self.chkfile``.  Nothing is written when
                ``self.chkfile`` is not set.

        Returns:
            The object itself.
        '''
        if isinstance(envs_or_file, str):
            chkfile.dump_scf(self.mol, envs_or_file, self.e_tot, self.mo_energy,
                             self.mo_coeff, self.mo_occ)
            return self

        if not self.chkfile:
            return self

        envs = envs_or_file
        # overwrite_mol=False is forwarded unchanged to chkfile.dump_scf.
        chkfile.dump_scf(self.mol, self.chkfile,
                         envs['e_tot'], envs['mo_energy'],
                         envs['mo_coeff'], envs['mo_occ'],
                         overwrite_mol=False)
        return self

    @lib.with_doc(init_guess_by_minao.__doc__)
    def init_guess_by_minao(self, mol=None):
        # mol defaults to the molecule attached to this object.
        target = self.mol if mol is None else mol
        logger.info(self, 'Initial guess from minao.')
        return init_guess_by_minao(target)

    @lib.with_doc(init_guess_by_atom.__doc__)
    def init_guess_by_atom(self, mol=None):
        # mol defaults to the molecule attached to this object.
        target = self.mol if mol is None else mol
        logger.info(self, 'Initial guess from superposition of atomic densities.')
        return init_guess_by_atom(target)

    @lib.with_doc(init_guess_by_huckel.__doc__)
    def init_guess_by_huckel(self, mol=None):
        # Build Huckel orbitals (original rule), occupy them with this
        # object's get_occ and return the resulting density matrix.
        target = self.mol if mol is None else mol
        logger.info(self, 'Initial guess from on-the-fly Huckel, doi:10.1021/acs.jctc.8b01089.')
        orb_energy, orb_coeff = _init_guess_huckel_orbitals(target, updated_rule=False)
        occupation = self.get_occ(orb_energy, orb_coeff)
        return self.make_rdm1(orb_coeff, occupation)

    @lib.with_doc(init_guess_by_mod_huckel.__doc__)
    def init_guess_by_mod_huckel(self, updated_rule, mol=None):
        if mol is None: mol = self.mol
        logger.info(self, '''Initial guess from on-the-fly Huckel, doi:10.1021/acs.jctc.8b01089,
employing the updated GWH rule from doi:10.1021/ja00480a005.''')
        mo_energy, mo_coeff = _init_guess_huckel_orbitals(mol, updated_rule=True)
        mo_occ = self.get_occ(mo_energy, mo_coeff)
        return self.make_rdm1(mo_coeff, mo_occ)

    @lib.with_doc(init_guess_by_1e.__doc__)
    def init_guess_by_1e(self, mol=None):
        # Diagonalise the core Hamiltonian in the AO overlap metric, occupy
        # the resulting orbitals and return their density matrix.
        target = self.mol if mol is None else mol
        logger.info(self, 'Initial guess from hcore.')
        hcore = self.get_hcore(target)
        ovlp = self.get_ovlp(target)
        orb_energy, orb_coeff = self.eig(hcore, ovlp)
        occupation = self.get_occ(orb_energy, orb_coeff)
        return self.make_rdm1(orb_coeff, occupation)

    @lib.with_doc(init_guess_by_sap.__doc__)
    def init_guess_by_sap(self, mol=None, **kwargs):
        from pyscf.gto.basis import load
        if mol is None: mol = self.mol
        sap_basis = self.sap_basis
        logger.info(self, '''Initial guess from superposition of atomic potentials (doi:10.1021/acs.jctc.8b01089)
This is the Gaussian fit version as described in doi:10.1063/5.0004046.''')
        if isinstance(sap_basis, str):
            atoms = [coord[0] for coord in mol._atom]
            sapbas = {}
            for atom in set(atoms):
                single_element_bs = load(sap_basis, atom)
                if isinstance(single_element_bs, dict):
                    sapbas[atom] = numpy.asarray(single_element_bs[atom][0][1:], dtype=float)
                else:
                    sapbas[atom] = numpy.asarray(single_element_bs[0][1:], dtype=float)
            logger.note(self, f'Found SAP basis {sap_basis.split("/")[-1]}')
        elif isinstance(sap_basis, dict):
            sapbas = {}
            for key in sap_basis:
                sapbas[key] = numpy.asarray(sap_basis[key][0][1:], dtype=float)
        else:
            logger.error(self, 'sap_basis is of an unexpected datatype.')
        return init_guess_by_sap(mol, sap_basis=sapbas, **kwargs)

    @lib.with_doc(init_guess_by_chkfile.__doc__)
    def init_guess_by_chkfile(self, chkfile=None, project=None):
        # chkfile defaults to the chkfile configured on this object.
        source = self.chkfile if chkfile is None else chkfile
        return init_guess_by_chkfile(self.mol, source, project=project)
    def from_chk(self, chkfile=None, project=None):
        guess = self.init_guess_by_chkfile(chkfile, project)
        return guess
    # Shares the documentation of init_guess_by_chkfile.
    from_chk.__doc__ = init_guess_by_chkfile.__doc__

    def get_init_guess(self, mol=None, key='minao', **kwargs):
        '''Produce the initial guess selected by ``key``.

        Args:
            mol : molecule to build the guess for; defaults to ``self.mol``.
            key : name of the guess (case-insensitive).  Recognised names are
                '1e'/'hcore', 'huckel', 'mod_huckel', 'atom', 'vsap' (only if
                the object provides ``init_guess_by_vsap``), 'sap' and
                anything starting with 'chk'.  Any other name gives the MINAO
                guess.  A non-str ``key`` is returned unchanged.
            **kwargs : forwarded to ``init_guess_by_sap`` only.

        Returns:
            The result of the selected ``init_guess_by_*`` method.
        '''
        if not isinstance(key, str):
            return key

        key = key.lower()
        if mol is None:
            mol = self.mol

        # These guesses are tried before the "no atoms" fallback below.
        if key in ('1e', 'hcore'):
            return self.init_guess_by_1e(mol)
        if key == 'huckel':
            return self.init_guess_by_huckel(mol)
        if key == 'mod_huckel':
            return self.init_guess_by_mod_huckel(mol)

        if getattr(mol, 'natm', 0) == 0:
            logger.info(self, 'No atom found in mol. Use 1e initial guess')
            return self.init_guess_by_1e(mol)

        if key == 'atom':
            return self.init_guess_by_atom(mol)
        if key == 'vsap' and hasattr(self, 'init_guess_by_vsap'):
            # Only available for DFT objects
            return self.init_guess_by_vsap(mol)
        if key == 'sap':
            return self.init_guess_by_sap(mol, **kwargs)
        if key.startswith('chk'):
            try:
                return self.init_guess_by_chkfile()
            except (IOError, KeyError):
                logger.warn(self, 'Fail in reading %s. Use MINAO initial guess',
                            self.chkfile)
                return self.init_guess_by_minao(mol)

        return self.init_guess_by_minao(mol)

    # Module-level make_rdm1/make_rdm2 wrapped by lib.module_method.
    # NOTE(review): presumably the arguments listed in ``absences`` are taken
    # from the instance when the caller omits them (scf() below calls
    # self.make_rdm1() without arguments) - confirm against lib.module_method.
    make_rdm1 = lib.module_method(make_rdm1, absences=['mo_coeff', 'mo_occ'])
    make_rdm2 = lib.module_method(make_rdm2, absences=['mo_coeff', 'mo_occ'])
    # Module-level energy functions bound directly as methods.
    energy_elec = energy_elec
    energy_tot = energy_tot

    # Dispersion-related methods taken from the pyscf.scf.dispersion module.
    do_disp = dispersion.check_disp
    get_dispersion = dispersion.get_dispersion

    def energy_nuc(self):
        '''Return the ``enuc`` attribute of the attached molecule.'''
        mol = self.mol
        return mol.enuc

    # Hook for overriding the convergence criteria used in the SCF iterations.
    # Assign a function with the signature
    #   f(envs) => bool
    # to check_convergence to replace the default convergence test.  The
    # default, None, means no custom test is installed.
    check_convergence = None

    def scf(self, dm0=None, **kwargs):
        '''SCF main driver

        Logs the settings, calls ``build`` and runs the module-level
        ``kernel``; the results are stored on the object.

        Kwargs:
            dm0 : ndarray
                If given, it will be used as the initial guess density matrix.
                If omitted and the object already holds orbitals and
                occupations, the guess is built from them.

        Returns:
            The total energy ``e_tot``.

        Examples:

        >>> import numpy
        >>> from pyscf import gto, scf
        >>> mol = gto.M(atom='H 0 0 0; F 0 0 1.1')
        >>> mf = scf.hf.SCF(mol)
        >>> dm_guess = numpy.eye(mol.nao_nr())
        >>> mf.kernel(dm_guess)
        converged SCF energy = -98.5521904482821
        -98.552190448282104
        '''
        t_start = (logger.process_clock(), logger.perf_counter())

        self.dump_flags()
        self.build(self.mol)

        has_wfn = self.mo_coeff is not None and self.mo_occ is not None
        if dm0 is None and has_wfn:
            # Initial guess from existing wavefunction
            dm0 = self.make_rdm1()

        # Decide before running the kernel whether its orbitals are adopted.
        # With max_cycle <= 0 and existing orbitals only the energy is
        # refreshed: this avoids updating SCF orbitals in the non-SCF
        # initialization (issue #495).  If the SCF was never initialized the
        # regular path is taken so that an initial guess is produced.
        adopt_all = self.max_cycle > 0 or self.mo_coeff is None
        results = kernel(self, self.conv_tol, self.conv_tol_grad,
                         dm0=dm0, callback=self.callback,
                         conv_check=self.conv_check, **kwargs)
        if adopt_all:
            (self.converged, self.e_tot,
             self.mo_energy, self.mo_coeff, self.mo_occ) = results
        else:
            self.e_tot = results[1]

        logger.timer(self, 'SCF', *t_start)
        self._finalize()
        return self.e_tot
    kernel = lib.alias(scf, alias_name='kernel')

    def _finalize(self):
        '''Hook called at the end of an SCF run; reports the final energy.

        Returns the object itself.
        '''
        if not self.converged:
            logger.note(self, 'SCF not converged.')
            logger.note(self, 'SCF energy = %.15g', self.e_tot)
        else:
            logger.note(self, 'converged SCF energy = %.15g', self.e_tot)
        return self

    def init_direct_scf(self, mol=None):
        '''Create the ``_vhf._VHFOpt`` object used by direct SCF.

        ``mol`` defaults to ``self.mol``; ``self.direct_scf_tol`` is passed as
        the last constructor argument.
        '''
        target = self.mol if mol is None else mol
        # Integrals < direct_scf_tol may be set to 0 in int2e.
        # Higher accuracy is required for Schwartz inequality prescreening.
        t_start = (logger.process_clock(), logger.perf_counter())
        vhfopt = _vhf._VHFOpt(target, 'int2e', 'CVHFnrs8_prescreen',
                              'CVHFnr_int2e_q_cond', 'CVHFnr_dm_cond',
                              self.direct_scf_tol)
        logger.timer(self, 'init_direct_scf', *t_start)
        return vhfopt

    @lib.with_doc(get_jk.__doc__)
    def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,
               omega=None):
        # mol and dm default to self.mol and self.make_rdm1().
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()
        t_start = (logger.process_clock(), logger.perf_counter())

        # One optimizer object is cached per omega in self._opt.
        if self.direct_scf and self._opt.get(omega) is None:
            # Be careful that opt has to be initialized with a proper setting of
            # omega. opt of regular ERI and SR ERI are incompatible since cint 5.4.0
            with mol.with_range_coulomb(omega):
                self._opt[omega] = self.init_direct_scf(mol)
        vhfopt = self._opt.get(omega)

        if with_j and with_k:
            vj, vk = get_jk(mol, dm, hermi, vhfopt, with_j, with_k, omega)
        else:
            # Only one of J/K is requested: temporarily switch the optimizer
            # to the matching prescreening function.
            # NOTE(review): vhfopt is None when direct_scf is off - confirm
            # lib.temporary_env tolerates a None target.
            prescreen = ('CVHFnrs8_vj_prescreen' if with_j
                         else 'CVHFnrs8_vk_prescreen')
            with lib.temporary_env(vhfopt, prescreen=prescreen):
                vj, vk = get_jk(mol, dm, hermi, vhfopt, with_j, with_k, omega)

        logger.timer(self, 'vj and vk', *t_start)
        return vj, vk

    def get_j(self, mol=None, dm=None, hermi=1, omega=None):
        '''Compute J matrices for all input density matrices.

        Calls ``get_jk`` with ``with_k=False`` and returns its first result.
        '''
        jk_pair = self.get_jk(mol, dm, hermi, with_k=False, omega=omega)
        return jk_pair[0]

    def get_k(self, mol=None, dm=None, hermi=1, omega=None):
        '''Compute K matrices for all input density matrices.

        Calls ``get_jk`` with ``with_j=False`` and returns its second result.
        '''
        jk_pair = self.get_jk(mol, dm, hermi, with_j=False, omega=omega)
        return jk_pair[1]

    @lib.with_doc(get_veff.__doc__)
    def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
        # Be careful with the effects of :attr:`SCF.direct_scf` on this function
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()

        if not self.direct_scf:
            vj, vk = self.get_jk(mol, dm, hermi=hermi)
            return vj - vk * .5

        # Direct SCF: evaluate J/K for the change of the density matrix only
        # and add the result to the previous potential.
        ddm = numpy.asarray(dm) - dm_last
        vj, vk = self.get_jk(mol, ddm, hermi=hermi)
        return vhf_last + vj - vk * .5

    @lib.with_doc(analyze.__doc__)
    def analyze(self, verbose=None, with_meta_lowdin=WITH_META_LOWDIN,
                **kwargs):
        # verbose defaults to the verbosity level of this object.
        level = self.verbose if verbose is None else verbose
        return analyze(self, level, with_meta_lowdin, **kwargs)

    # Module-level dump_scf_summary bound as a method.
    dump_scf_summary = dump_scf_summary

    @lib.with_doc(mulliken_pop.__doc__)
    def mulliken_pop(self, mol=None, dm=None, s=None, verbose=logger.DEBUG):
        # Missing arguments are filled in from this object: the attached
        # molecule, its current density matrix and the overlap of ``mol``.
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()
        if s is None:
            s = self.get_ovlp(mol)
        return mulliken_pop(mol, dm, s=s, verbose=verbose)

    @lib.with_doc(mulliken_meta.__doc__)
    def mulliken_meta(self, mol=None, dm=None, verbose=logger.DEBUG,
                      pre_orth_method=PRE_ORTH_METHOD, s=None):
        # Missing arguments are filled in from this object: the attached
        # molecule, its current density matrix and the overlap of ``mol``.
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()
        if s is None:
            s = self.get_ovlp(mol)
        return mulliken_meta(mol, dm, s=s, verbose=verbose,
                             pre_orth_method=pre_orth_method)
    def pop(self, *args, **kwargs):
        result = self.mulliken_meta(*args, **kwargs)
        return result
    # pop shares the documentation of mulliken_meta and is also reachable
    # under a longer descriptive name.
    pop.__doc__ = mulliken_meta.__doc__
    mulliken_pop_meta_lowdin_ao = pop

    # Module-level canonicalize bound as a method.
    canonicalize = canonicalize

    @lib.with_doc(dip_moment.__doc__)
    def dip_moment(self, mol=None, dm=None, unit='Debye', origin=None, verbose=logger.NOTE,
                   **kwargs):
        # mol and dm default to the attached molecule and the current density
        # matrix of this object.
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()
        return dip_moment(mol, dm, unit, origin=origin, verbose=verbose, **kwargs)

    @lib.with_doc(quad_moment.__doc__)
    def quad_moment(self, mol=None, dm=None, unit='DebyeAngstrom', origin=None,
                    verbose=logger.NOTE, **kwargs):
        # mol and dm default to the attached molecule and the current density
        # matrix of this object.
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()
        return quad_moment(mol, dm, unit=unit, origin=origin,
                           verbose=verbose, **kwargs)

    def _is_mem_enough(self):
        '''Whether nao**4/1e6 plus the current memory use stays below 95% of
        ``self.max_memory``.
        '''
        nao = self.mol.nao_nr()
        estimate = nao**4/1e6
        budget = self.max_memory*.95
        return estimate+lib.current_memory()[0] < budget

    def density_fit(self, auxbasis=None, with_df=None, only_dfj=False):
        '''Apply ``pyscf.df.df_jk.density_fit`` to this object and return the
        result.  All arguments are forwarded unchanged.

        A warning is logged when the object already carries a solvent model.
        '''
        import pyscf.df.df_jk
        already_solvated = self.istype('_Solvation')
        if already_solvated:
            logger.warn(self,
                'It is recommended to call density_fit() before applying a solvent model. '
                'Calling density_fit() after the solvent model may result in '
                'incorrect nuclear gradients, TDDFT and other methods.')
        df_decorator = pyscf.df.df_jk.density_fit
        return df_decorator(self, auxbasis, with_df, only_dfj)

    def multigrid_numint(self, margin=None, mesh=None):
        '''Apply the MultiGrid algorithm for XC numerical integration.

        Not implemented in this class: calling it raises NotImplementedError.

        Kwargs:
            margin : float
                A box will be created to enclose the molecule, with the molecule
                positioned at the center. "margin" specifies the distance from
                the edge of the molecule to the edge of the box. If not provided,
                a default margin is estimated, which ensures that the electron
                density decays to approximately 1e-7 at the boundary of the box.
            mesh : (3,) ndarray
                The number of mesh grids along each axis. If not specified, the
                number of mesh grids will be estimated based on the basis sets
                and the margin.
        '''
        raise NotImplementedError

    def sfx2c1e(self):
        '''Return the result of ``pyscf.x2c.sfx2c1e.sfx2c1e`` applied to this
        object.
        '''
        import pyscf.x2c.sfx2c1e as x2c_module
        return x2c_module.sfx2c1e(self)
    # Shorter aliases of the same method.
    x2c1e = sfx2c1e
    x2c = sfx2c1e

    def newton(self):
        '''Create an SOSCF object based on the mean-field object'''
        from pyscf.soscf import newton_ah as soscf_module
        soscf_obj = soscf_module.newton(self)
        return soscf_obj

    def remove_soscf(self):
        '''Remove the SOSCF decorator.

        Objects that are not SOSCF-decorated are returned unchanged.
        '''
        from pyscf.soscf import newton_ah
        if isinstance(self, newton_ah._CIAH_SOSCF):
            return self.undo_soscf()
        return self

    def stability(self):
        '''Hook for wavefunction stability analysis.

        Not implemented in this base class; RHF provides an implementation.
        '''
        raise NotImplementedError

    def nuc_grad_method(self):  # pragma: no cover
        '''Hook to create object for analytical nuclear gradients.

        Not implemented in this base class; RHF provides an implementation.
        '''
        raise NotImplementedError

    def update_(self, chkfile=None):
        '''Read attributes from the chkfile then replace the attributes of
        current object.  It's an alias of function update_from_chk_.

        ``chkfile`` defaults to ``self.chkfile``.  A warning is logged when
        the AO dimension of the stored orbitals matches neither ``mol.nao``
        nor twice that number.  Returns the object itself.
        '''
        from pyscf.scf import chkfile as chkmod
        if chkfile is None:
            chkfile = self.chkfile
        stored = chkmod.load(chkfile, 'scf')
        nao = self.mol.nao

        # Locate one coefficient array inside the (possibly nested) record.
        mo = stored['mo_coeff']
        if isinstance(mo, numpy.ndarray): # RHF
            sample = mo
        elif isinstance(mo[0], numpy.ndarray): # UHF
            sample = mo[0]
        else: # KUHF
            sample = mo[0][0]
        mo_nao = sample.shape[-2]

        if mo_nao != nao and mo_nao != nao*2:
            logger.warn(self, 'Current mol is inconsistent with SCF object in '
                        'chkfile %s', chkfile)
        self.__dict__.update(stored)
        return self
    update_from_chk = update_from_chk_ = update = update_

    # Module-level as_scanner bound as a method.
    as_scanner = as_scanner

    def reset(self, mol=None):
        '''Reset mol and relevant attributes associated to the old mol object.

        ``self.mol`` is replaced only when ``mol`` is given.  The cached
        ``_opt`` and ``_eri`` and the ``scf_summary`` are always cleared.
        Returns the object itself.
        '''
        replace_mol = mol is not None
        if replace_mol:
            self.mol = mol
        self._opt = dict.fromkeys((None,))
        self._eri = None
        self.scf_summary = dict()
        return self

    def apply(self, fn, *args, **kwargs):
        '''Apply ``fn`` to this object.

        A callable ``fn`` is handled by ``lib.StreamObject.apply``.  A str is
        upper-cased and looked up in the mp, cc, ci, mcscf and tdscf modules
        (in that order); the first callable match is invoked as
        ``method(self, *args, **kwargs)``.  The SCF is run first if
        ``mo_coeff`` is still None.

        Raises:
            ValueError: no module provides a callable of that name.
            TypeError: ``fn`` is neither callable nor a str.
        '''
        if callable(fn):
            return lib.StreamObject.apply(self, fn, *args, **kwargs)

        if not isinstance(fn, str):
            raise TypeError('First argument of .apply method must be a '
                            'function/class or a name (string) of a method.')

        from pyscf import mp, cc, ci, mcscf, tdscf
        for mod in (mp, cc, ci, mcscf, tdscf):
            method = getattr(mod, fn.upper(), None)
            if method is None or not callable(method):
                continue
            if self.mo_coeff is None:
                logger.warn(self, 'SCF object must be initialized '
                            'before calling post-SCF methods.\n'
                            'Initialize %s for %s', self, mod)
                self.kernel()
            return method(self, *args, **kwargs)
        raise ValueError('Unknown method %s' % fn)

    def to_rhf(self):
        '''Convert the input mean-field object to a RHF/ROHF object.

        Only the class of the mean-field object changes; the total energy and
        wave-function are the same as those of the input object.
        '''
        from pyscf.scf import addons as scf_addons
        converted = scf_addons.convert_to_rhf(self)
        return converted

    def to_uhf(self):
        '''Convert the input mean-field object to a UHF object.

        Only the class of the mean-field object changes; the total energy and
        wave-function are the same as those of the input object.
        '''
        from pyscf.scf import addons as scf_addons
        converted = scf_addons.convert_to_uhf(self)
        return converted

    def to_ghf(self):
        '''Convert the input mean-field object to a GHF object.

        Only the class of the mean-field object changes; the total energy and
        wave-function are the same as those of the input object.
        '''
        from pyscf.scf import addons as scf_addons
        converted = scf_addons.convert_to_ghf(self)
        return converted

    def to_rks(self, xc='HF'):
        '''Convert the input mean-field object to a RKS/ROKS object.

        Implemented as ``self.to_rhf().to_ks(xc)``.  Only the class of the
        mean-field object changes; the total energy and wave-function are the
        same as those of the input object.
        '''
        restricted = self.to_rhf()
        return restricted.to_ks(xc)

    def to_uks(self, xc='HF'):
        '''Convert the input mean-field object to a UKS object.

        Implemented as ``self.to_uhf().to_ks(xc)``.  Only the class of the
        mean-field object changes; the total energy and wave-function are the
        same as those of the input object.
        '''
        unrestricted = self.to_uhf()
        return unrestricted.to_ks(xc)

    def to_gks(self, xc='HF'):
        '''Convert the input mean-field object to a GKS object.

        Implemented as ``self.to_ghf().to_ks(xc)``.  Only the class of the
        mean-field object changes; the total energy and wave-function are the
        same as those of the input object.
        '''
        generalized = self.to_ghf()
        return generalized.to_ks(xc)

    def convert_from_(self, mf):
        '''Convert the input mean-field object ``mf`` to the type of the
        current object.

        Not implemented in this base class.  The RHF override converts ``mf``
        with ``mf.to_rhf()`` and copies the result's attributes into ``self``.
        '''
        raise NotImplementedError

    def to_ks(self, xc='HF'):
        '''Convert the input mean-field object to the associated KS object.

        Note this conversion only changes the class of the mean-field object.
        The total energy and wave-function are the same as them in the input
        mean-field object.

        Not implemented in this base class; RHF provides an implementation.
        '''
        raise NotImplementedError

    def _transfer_attrs_(self, dst):
        '''Copy the attributes of this SCF object onto another SCF object.

        Helper invoked by the to_ks and to_hf methods.  Only attributes that
        ``dst`` already tracks (present in its ``__dict__`` or listed in the
        ``_keys`` of its classes) are copied.  ``dst.converged`` is reset to
        False.  Returns ``dst``, or its density-fitted replacement.
        '''
        from pyscf.df.df_jk import _DFHF
        if isinstance(self, _DFHF) and not hasattr(dst, 'with_df'):
            # * Handle DF_SCF instances for to_xxx methods.
            # * Only the molecular SCF methods need to be explicitly converted.
            #   For PBC SCF methods, DF is enabled by default. calling density_fit()
            #   may alter the DF class. Conversion should be avoided here.
            dst = dst.density_fit(auxbasis=self.with_df.auxbasis)

        # Search for all tracked attributes, including those in base classes
        # (the last MRO entry is left out).
        tracked = set(dst.__dict__)
        for klass in dst.__class__.__mro__[:-1]:
            tracked = tracked.union(getattr(klass, '_keys', ()))

        src_attrs = self.__dict__
        shared = set(src_attrs).intersection(tracked)
        dst.__dict__.update({name: src_attrs[name] for name in shared})
        dst.converged = False
        return dst

    def to_gpu(self):
        '''Converts to the object with GPU support.

        Not implemented in this base class; RHF binds ``lib.to_gpu`` instead.
        '''
        raise NotImplementedError

    def istype(self, type_code):
        '''
        Test whether the object is an instance of the class given by type_code.

        type_code may be a class or a str.  A class is checked with the
        built-in `isinstance`.  A str is compared against the names of the
        object's class and all of its parent classes.
        '''
        if not isinstance(type_code, type):
            # type_code is a name: look for it along the MRO
            for klass in self.__class__.__mro__:
                if type_code == klass.__name__:
                    return True
            return False

        # type_code is a class
        return isinstance(self, type_code)


class KohnShamDFT:
    '''A mock DFT base class

    The base class KohnShamDFT is defined in the dft.rks module. This class can
    be used to verify if an SCF object is a Hartree-Fock method or a DFT method.
    It should be overwritten by the actual KohnShamDFT class when loading dft module.

    This placeholder defines no attributes or methods of its own.
    '''


class RHF(SCF):
    __doc__ = SCF.__doc__

    def check_sanity(self):
        '''Warn about an open-shell molecule with more than one electron,
        then run the base-class sanity checks.
        '''
        mol = self.mol
        if not (mol.nelectron == 1 or mol.spin == 0):
            logger.warn(self, 'Invalid number of electrons %d for RHF method.',
                        mol.nelectron)
        return SCF.check_sanity(self)

    def get_init_guess(self, mol=None, key='minao', **kwargs):
        '''Initial guess from ``SCF.get_init_guess``; at DEBUG1 verbosity the
        electron count of the guess is additionally logged.
        '''
        dm = SCF.get_init_guess(self, mol, key, **kwargs)
        if self.verbose < logger.DEBUG1:
            return dm

        ovlp = self.get_ovlp()
        nelec = numpy.einsum('ij,ji', dm, ovlp).real
        logger.debug1(self, 'Nelec from initial guess = %s', nelec)
        return dm

    @lib.with_doc(get_jk.__doc__)
    def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,
               omega=None):
        # Note the incore version, which initializes an _eri array in memory.
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()

        # The incore path is taken only when omega is unset/zero and either
        # the integrals are cached already, incore is forced on the molecule
        # or the memory estimate allows it.
        use_incore = (not omega and
                      (self._eri is not None or mol.incore_anyway or
                       self._is_mem_enough()))
        if use_incore:
            if self._eri is None:
                self._eri = mol.intor('int2e', aosym='s8')
            vj, vk = dot_eri_dm(self._eri, dm, hermi, with_j, with_k)
        else:
            vj, vk = SCF.get_jk(self, mol, dm, hermi, with_j, with_k, omega)
        return vj, vk

    @lib.with_doc(get_veff.__doc__)
    def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1):
        if mol is None:
            mol = self.mol
        if dm is None:
            dm = self.make_rdm1()

        # The incremental build is used only for direct SCF without cached
        # integrals; otherwise the potential is built from the full density.
        incremental = self._eri is None and self.direct_scf
        if incremental:
            ddm = numpy.asarray(dm) - numpy.asarray(dm_last)
            vj, vk = self.get_jk(mol, ddm, hermi)
            vhf = vj - vk * .5
            vhf += numpy.asarray(vhf_last)
            return vhf

        vj, vk = self.get_jk(mol, dm, hermi)
        return vj - vk * .5

    def convert_from_(self, mf):
        '''Convert the input mean-field object to RHF/ROHF'''
        converted = mf.to_rhf()
        self.__dict__.update(converted.__dict__)
        return self

    def spin_square(self, mo_coeff=None, s=None):  # pragma: no cover
        '''Spin square and multiplicity of RHF determinant'''
        return 0, 1

    def stability(self,
                  internal=getattr(__config__, 'scf_stability_internal', True),
                  external=getattr(__config__, 'scf_stability_external', False),
                  verbose=None,
                  return_status=False,
                  **kwargs):
        '''
        RHF/RKS stability analysis.

        See also pyscf.scf.stability.rhf_stability function.

        Kwargs:
            internal : bool
                Internal stability, within the RHF optimization space.
            external : bool
                External stability. Including the RHF -> UHF and real -> complex
                stability analysis.
            return_status: bool
                Whether to return `stable_i` and `stable_e`

        Returns:
            If return_status is False (default), the return value includes
            two set of orbitals, which are more close to the stable condition.
            The first corresponds to the internal stability
            and the second corresponds to the external stability.

            Else, another two boolean variables (indicating current status:
            stable or unstable) are returned.
            The first corresponds to the internal stability
            and the second corresponds to the external stability.
        '''
        from pyscf.scf.stability import rhf_stability as run_stability
        return run_stability(self, internal, external, verbose, return_status,
                             **kwargs)

    def nuc_grad_method(self):
        '''Create the ``pyscf.grad.rhf.Gradients`` object for this method.'''
        from pyscf.grad import rhf as rhf_grad
        return rhf_grad.Gradients(self)

    def to_ks(self, xc='HF'):
        '''Convert to RKS object.
        '''
        from pyscf import dft
        ks_obj = dft.RKS(self.mol, xc=xc)
        return self._transfer_attrs_(ks_obj)

    # FIXME: consider the density_fit, x2c and soscf decoration
    to_gpu = lib.to_gpu

def _hf1e_scf(mf, *args):
    '''Solve the one-electron problem by a single diagonalisation.

    The core Hamiltonian is diagonalised in the overlap metric; the orbitals,
    occupations and total energy are stored on ``mf``, which is marked as
    converged.  Extra positional arguments are ignored.

    Returns:
        The total energy: the energy of the first occupied orbital plus the
        nuclear repulsion energy.
    '''
    logger.info(mf, '\n')
    logger.info(mf, '******** 1 electron system ********')
    mf.converged = True

    mol = mf.mol
    hcore = mf.get_hcore(mol)
    ovlp = mf.get_ovlp(mf.mol)
    mf.mo_energy, mf.mo_coeff = mf.eig(hcore, ovlp)
    mf.mo_occ = mf.get_occ(mf.mo_energy, mf.mo_coeff)

    occupied_energies = mf.mo_energy[mf.mo_occ>0]
    mf.e_tot = occupied_energies[0].real + mf.mol.energy_nuc()
    mf._finalize()
    return mf.e_tot


del (WITH_META_LOWDIN, PRE_ORTH_METHOD)


if __name__ == '__main__':
    from pyscf import scf

    # Demo: helium atom in the cc-pVDZ basis.
    mol = gto.Mole()
    mol.verbose = 5
    mol.output = None
    mol.atom = [['He', (0, 0, 0)], ]
    mol.basis = 'ccpvdz'
    mol.build(0, 0)

    ##############
    # SCF result
    # RHF decorated with x2c, density fitting and the second-order solver,
    # started from the one-electron (hcore) guess.
    mf = scf.RHF(mol).x2c().density_fit().newton()
    mf.init_guess = '1e'
    e_tot = mf.scf()
    print(e_tot)
